From 502bedc137057c63d71c7f459c1fcb353c747575 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 21:28:40 +0200 Subject: [PATCH 01/23] Refactor prompts: local types and builder Replace external SDK prompt models with local PromptContext/PromptMessage/PromptBlock types and remove reliance on bridgesdk. Add prompt_builder, prompt_context_local and prompt_projection_local to centralize prompt assembly, history replay, and current-turn text construction (including deterministic handling of untrusted prefixes and link context). Consolidate history loading into replayHistoryMessages and fetchHistoryRowsWithExtra, update media handling (PDFs normalized into file-context text; audio/video must be preprocessed), and trim/normalize system prompt text passed to Responses. Remove ensureCanonicalUserMessage and canonical_user_messages.go, update many call sites and tests, and adjust default command prefix resolution to use ResolveCommandPrefix. This simplifies prompt lifecycle, reduces cross-package coupling, and standardizes multimodal input conversion. --- bridges/ai/agent_loop_request_builders.go | 8 +- .../ai/agent_loop_request_builders_test.go | 4 +- bridges/ai/agent_loop_routing_test.go | 10 +- bridges/ai/canonical_prompt_messages.go | 20 +- bridges/ai/canonical_user_messages.go | 25 - bridges/ai/client.go | 214 ++----- bridges/ai/constructors.go | 2 +- bridges/ai/constructors_test.go | 3 + bridges/ai/handlematrix.go | 100 +--- bridges/ai/heartbeat_execute.go | 4 +- bridges/ai/identifiers.go | 6 +- bridges/ai/image_understanding.go | 80 +-- bridges/ai/internal_dispatch.go | 1 - bridges/ai/media_understanding_runner.go | 38 +- bridges/ai/messages.go | 68 ++- bridges/ai/messages_responses_input_test.go | 19 +- bridges/ai/prompt_builder.go | 178 ++++++ bridges/ai/prompt_context_local.go | 371 ++++++++++++ bridges/ai/prompt_projection_local.go | 166 ++++++ bridges/ai/provider_openai_chat.go | 4 +- bridges/ai/provider_openai_responses.go | 8 +- bridges/ai/provider_openai_responses_test.go | 32 +- bridges/ai/response_retry.go | 3 +- bridges/ai/session_greeting.go | 21 +- bridges/ai/streaming_chat_completions.go | 4 - bridges/ai/streaming_continuation.go | 8 +- bridges/ai/streaming_input_conversion.go | 7 - bridges/ai/streaming_responses_api.go | 9 +- bridges/ai/streaming_responses_input_test.go | 36 +- bridges/ai/streaming_state.go | 1 + bridges/ai/subagent_announce.go | 4 +- bridges/ai/subagent_spawn.go | 1 - bridges/ai/system_prompts.go | 61 +- bridges/ai/text_files.go | 41 ++ bridges/ai/tools_analyze_image.go | 5 +- connector_builder.go | 8 +- connector_builder_test.go | 17 +- sdk/connector.go | 12 +- sdk/connector_helpers.go | 9 + sdk/connector_hooks_test.go | 22 +- sdk/prompt_context.go | 557 ------------------ sdk/prompt_projection.go | 284 --------- sdk/turn_data_test.go | 63 -- sdk/turn_snapshot.go | 2 - 44 files changed, 1050 insertions(+), 1486 deletions(-) delete mode 100644 bridges/ai/canonical_user_messages.go create mode 100644 bridges/ai/prompt_builder.go create mode 100644 bridges/ai/prompt_context_local.go create mode 100644 bridges/ai/prompt_projection_local.go delete mode 100644 sdk/prompt_context.go delete mode 100644 sdk/prompt_projection.go diff --git a/bridges/ai/agent_loop_request_builders.go b/bridges/ai/agent_loop_request_builders.go index 3a84b747..a6e2be3a 100644 --- a/bridges/ai/agent_loop_request_builders.go +++ b/bridges/ai/agent_loop_request_builders.go @@ -2,6 +2,7 @@ package ai import ( "context" + "strings" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/packages/param" @@ -16,7 +17,6 @@ type agentLoopRequestSettings struct { model string maxTokens int temperature *float64 - systemPrompt string reasoningEffort string } @@ -25,7 +25,6 @@ func (oc *AIClient) buildAgentLoopRequestSettings(meta *PortalMetadata) agentLoo model: oc.effectiveModelForAPI(meta), maxTokens: oc.effectiveMaxTokens(meta), temperature: oc.effectiveTemperature(meta), - systemPrompt: oc.effectivePrompt(meta), reasoningEffort: oc.effectiveReasoningEffort(meta), } } @@ -101,6 +100,7 @@ func (oc *AIClient) buildChatCompletionsAgentLoopParams( func (oc *AIClient) buildResponsesAgentLoopParams( ctx context.Context, meta *PortalMetadata, + systemPrompt string, input responses.ResponseInputParam, allowResolvedBossAgent bool, ) responses.ResponseNewParams { @@ -119,8 +119,8 @@ func (oc *AIClient) buildResponsesAgentLoopParams( if settings.temperature != nil { params.Temperature = openai.Float(*settings.temperature) } - if settings.systemPrompt != "" { - params.Instructions = openai.String(settings.systemPrompt) + if strings.TrimSpace(systemPrompt) != "" { + params.Instructions = openai.String(strings.TrimSpace(systemPrompt)) } if effort, ok := reasoningEffortMap[settings.reasoningEffort]; ok { params.Reasoning = shared.ReasoningParam{ diff --git a/bridges/ai/agent_loop_request_builders_test.go b/bridges/ai/agent_loop_request_builders_test.go index 41caf0aa..20e6de3a 100644 --- a/bridges/ai/agent_loop_request_builders_test.go +++ b/bridges/ai/agent_loop_request_builders_test.go @@ -37,7 +37,7 @@ func TestAgentLoopRequestBuildersShareModelAndTokenSettings(t *testing.T) { chatParams := oc.buildChatCompletionsAgentLoopParams(context.Background(), meta, []openai.ChatCompletionMessageParamUnion{ openai.UserMessage("hello"), }) - responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, nil, false) + responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, "system prompt", nil, false) if chatParams.Model != "openai/gpt-5.2" { t.Fatalf("expected chat model openai/gpt-5.2, got %q", chatParams.Model) @@ -96,7 +96,7 @@ func TestAgentLoopRequestBuildersPreserveExplicitZeroTemperature(t *testing.T) { chatParams := oc.buildChatCompletionsAgentLoopParams(context.Background(), meta, []openai.ChatCompletionMessageParamUnion{ openai.UserMessage("hello"), }) - responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, nil, false) + responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, "system prompt", nil, false) if !chatParams.Temperature.Valid() || chatParams.Temperature.Value != 0 { t.Fatalf("expected explicit zero chat temperature, got %#v", chatParams.Temperature) diff --git a/bridges/ai/agent_loop_routing_test.go b/bridges/ai/agent_loop_routing_test.go index 2b92eda0..57bd284a 100644 --- a/bridges/ai/agent_loop_routing_test.go +++ b/bridges/ai/agent_loop_routing_test.go @@ -7,8 +7,6 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func newAgentLoopRoutingTestClient(models ...ModelInfo) *AIClient { @@ -45,13 +43,7 @@ func TestSelectAgentLoopRunFunc_UsesChatCompletionsForUnsupportedResponsesPrompt API: string(ModelAPIResponses), }) - promptContext := PromptContext{ - PromptContext: bridgesdk.UserPromptContext(bridgesdk.PromptBlock{ - Type: bridgesdk.PromptBlockAudio, - AudioB64: "YXVkaW8=", - AudioFormat: "mp3", - }), - } + promptContext := UserPromptContext(PromptBlock{Type: PromptBlockType("unknown")}) responseFn, logLabel := oc.selectAgentLoopRunFunc(resolvedModelMeta("openai/gpt-4.1"), promptContext) if responseFn == nil { diff --git a/bridges/ai/canonical_prompt_messages.go b/bridges/ai/canonical_prompt_messages.go index db8c578f..e263627f 100644 --- a/bridges/ai/canonical_prompt_messages.go +++ b/bridges/ai/canonical_prompt_messages.go @@ -2,13 +2,11 @@ package ai import ( "strings" - - "github.com/beeper/agentremote/sdk" ) func promptMessagesFromMetadata(meta *MessageMetadata) []PromptMessage { if turnData, ok := canonicalTurnData(meta); ok { - return sdk.PromptMessagesFromTurnData(turnData) + return promptMessagesFromTurnData(turnData) } return nil } @@ -50,20 +48,6 @@ func filterPromptBlocksForHistory(blocks []PromptBlock, injectImages bool) []Pro return filtered } -func textPromptMessage(text string) []PromptMessage { - text = strings.TrimSpace(text) - if text == "" { - return nil - } - return []PromptMessage{{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: text, - }}, - }} -} - func promptTail(ctx PromptContext, count int) []PromptMessage { if count <= 0 || len(ctx.Messages) == 0 { return nil @@ -80,7 +64,7 @@ func setCanonicalTurnDataFromPromptMessages(meta *MessageMetadata, messages []Pr if meta == nil || len(messages) == 0 { return } - if turnData, ok := sdk.TurnDataFromUserPromptMessages(messages); ok { + if turnData, ok := turnDataFromUserPromptMessages(messages); ok { meta.CanonicalTurnData = turnData.ToMap() } else { meta.CanonicalTurnData = nil diff --git a/bridges/ai/canonical_user_messages.go b/bridges/ai/canonical_user_messages.go deleted file mode 100644 index 7a84b89a..00000000 --- a/bridges/ai/canonical_user_messages.go +++ /dev/null @@ -1,25 +0,0 @@ -package ai - -import ( - "strings" - - "maunium.net/go/mautrix/bridgev2/database" -) - -func ensureCanonicalUserMessage(msg *database.Message) { - if msg == nil { - return - } - meta, ok := msg.Metadata.(*MessageMetadata) - if !ok || meta == nil || strings.TrimSpace(meta.Role) != "user" { - return - } - if len(meta.CanonicalTurnData) > 0 { - return - } - - body := strings.TrimSpace(meta.Body) - if body != "" { - setCanonicalTurnDataFromPromptMessages(meta, textPromptMessage(body)) - } -} diff --git a/bridges/ai/client.go b/bridges/ai/client.go index c6faf97f..2001fdc8 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -28,7 +28,6 @@ import ( "github.com/beeper/agentremote/pkg/agents" airuntime "github.com/beeper/agentremote/pkg/runtime" "github.com/beeper/agentremote/pkg/shared/stringutil" - bridgesdk "github.com/beeper/agentremote/sdk" ) var ( @@ -604,7 +603,6 @@ func (oc *AIClient) saveUserMessage(ctx context.Context, evt *event.Event, msg * if evt != nil { msg.MXID = evt.ID } - ensureCanonicalUserMessage(msg) if _, err := oc.UserLogin.Bridge.GetGhostByID(ctx, msg.SenderID); err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure user ghost before saving message") } @@ -1669,7 +1667,7 @@ func (oc *AIClient) promptContextToDispatchMessages( meta *PortalMetadata, promptContext PromptContext, ) []openai.ChatCompletionMessageParamUnion { - promptMessages := bridgesdk.PromptContextToChatCompletionMessages(promptContext.PromptContext, oc.isOpenRouterProvider()) + promptMessages := PromptContextToChatCompletionMessages(promptContext, oc.isOpenRouterProvider()) promptMessages = oc.augmentPromptWithIntegrations(ctx, portal, meta, promptMessages) if meta != nil && IsGoogleModel(oc.effectiveModel(meta)) { promptMessages = SanitizeGoogleTurnOrdering(promptMessages) @@ -1683,57 +1681,12 @@ type historyLoadResult struct { resetAt int64 } -// fetchHistoryRows resolves the history limit, extracts the resetAt cutoff, -// and fetches the last N messages. Returns nil when the history limit is zero. -func (oc *AIClient) fetchHistoryRows( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, -) (*historyLoadResult, error) { - historyLimit := oc.historyLimit(ctx, portal, meta) - if historyLimit <= 0 { - return nil, nil - } - resetAt := int64(0) - if meta != nil { - resetAt = meta.SessionResetAt - } - history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(ctx, portal.PortalKey, historyLimit) - if err != nil { - return nil, fmt.Errorf("failed to load prompt history: %w", err) - } - return &historyLoadResult{ - rows: history, - hasVision: oc.getModelCapabilitiesForMeta(ctx, meta).SupportsVision, - resetAt: resetAt, - }, nil -} - func (oc *AIClient) loadHistoryMessages( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, ) ([]PromptMessage, error) { - hr, err := oc.fetchHistoryRows(ctx, portal, meta) - if err != nil { - return nil, err - } - if hr == nil { - return nil, nil - } - var messages []PromptMessage - for i := len(hr.rows) - 1; i >= 0; i-- { - msgMeta := messageMeta(hr.rows[i]) - if !shouldIncludeInHistory(msgMeta) { - continue - } - if hr.resetAt > 0 && hr.rows[i].Timestamp.UnixMilli() < hr.resetAt { - continue - } - injectImages := hr.hasVision && i < maxHistoryImageMessages - messages = append(messages, oc.historyMessageBundle(ctx, msgMeta, injectImages)...) - } - return messages, nil + return oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{mode: historyReplayNormal}) } func (oc *AIClient) buildBaseContext( @@ -1741,9 +1694,9 @@ func (oc *AIClient) buildBaseContext( portal *bridgev2.Portal, meta *PortalMetadata, ) (PromptContext, error) { - var promptContext PromptContext - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, maybePrependSessionGreeting(ctx, portal, meta, nil, oc.log)) - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, oc.buildSystemMessages(ctx, portal, meta)) + promptContext := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, true), + } historyMessages, err := oc.loadHistoryMessages(ctx, portal, meta) if err != nil { @@ -1788,8 +1741,9 @@ type inboundPromptResult struct { } // prepareInboundPromptContext builds the base context, resolves inbound context, -// appends the meta system prompt, resolves body overrides, and applies the abort hint. -// Callers must call applyUntrustedPrefix at the appropriate point in message assembly. +// appends trusted inbound metadata to the system prompt, resolves body overrides, +// and applies the abort hint. Untrusted inbound prefixes are returned separately +// so callers can place them deterministically in the user prompt body. func (oc *AIClient) prepareInboundPromptContext( ctx context.Context, portal *bridgev2.Portal, @@ -1802,7 +1756,7 @@ func (oc *AIClient) prepareInboundPromptContext( return inboundPromptResult{}, err } inboundCtx := oc.resolvePromptInboundContext(ctx, portal, userText, eventID) - bridgesdk.AppendPromptText(&promptContext.SystemPrompt, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) + AppendPromptText(&promptContext.SystemPrompt, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) resolved := strings.TrimSpace(userText) if body := strings.TrimSpace(inboundCtx.BodyForAgent); body != "" { @@ -1818,13 +1772,6 @@ func (oc *AIClient) prepareInboundPromptContext( UntrustedPrefix: untrustedPrefix, }, nil } - -func (r *inboundPromptResult) applyUntrustedPrefix() { - if r.UntrustedPrefix != "" { - r.ResolvedBody = r.UntrustedPrefix + "\n\n" + r.ResolvedBody - } -} - func (oc *AIClient) buildContextWithLinkContext( ctx context.Context, portal *bridgev2.Portal, @@ -1833,34 +1780,21 @@ func (oc *AIClient) buildContextWithLinkContext( rawEventContent map[string]any, eventID id.EventID, ) (PromptContext, error) { - result, err := oc.prepareInboundPromptContext(ctx, portal, meta, latest, eventID) + promptContext, text, err := oc.buildCurrentTurnText(ctx, portal, meta, latest, eventID, currentTurnTextOptions{ + rawEventContent: rawEventContent, + includeLinkScope: true, + }) if err != nil { return PromptContext{}, err } - - if linkContext := oc.buildLinkContext(ctx, latest, rawEventContent); linkContext != "" { - result.ResolvedBody += linkContext - } - - if portal != nil && portal.MXID != "" { - reactionFeedback := DrainReactionFeedback(portal.MXID) - if len(reactionFeedback) > 0 { - if feedbackText := FormatReactionFeedback(reactionFeedback); feedbackText != "" { - result.ResolvedBody = feedbackText + "\n" + result.ResolvedBody - } - } - } - - result.applyUntrustedPrefix() - - result.PromptContext.Messages = append(result.PromptContext.Messages, PromptMessage{ + promptContext.Messages = append(promptContext.Messages, PromptMessage{ Role: PromptRoleUser, Blocks: []PromptBlock{{ Type: PromptBlockText, - Text: result.ResolvedBody, + Text: text, }}, }) - return result.PromptContext, nil + return promptContext, nil } // buildLinkContext extracts URLs from the message, fetches previews, and returns formatted context. @@ -1935,15 +1869,8 @@ func (oc *AIClient) buildContextWithMedia( mediaType pendingMessageType, eventID id.EventID, ) (PromptContext, error) { - result, err := oc.prepareInboundPromptContext(ctx, portal, meta, caption, eventID) - if err != nil { - return PromptContext{}, err - } - result.applyUntrustedPrefix() + appendBlocks := make([]string, 0, 1) blocks := make([]PromptBlock, 0, 2) - if strings.TrimSpace(result.ResolvedBody) != "" { - blocks = append(blocks, PromptBlock{Type: PromptBlockText, Text: result.ResolvedBody}) - } switch mediaType { case pendingTypeImage: @@ -1958,44 +1885,38 @@ func (oc *AIClient) buildContextWithMedia( }) case pendingTypePDF: - b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, mediaURL, encryptedFile, 50, mimeType) // 50MB limit + content, truncated, err := oc.downloadPDFFile(ctx, mediaURL, encryptedFile, mimeType) if err != nil { return PromptContext{}, fmt.Errorf("failed to download PDF: %w", err) } - if actualMimeType == "" { - actualMimeType = "application/pdf" - } - blocks = append(blocks, PromptBlock{ - Type: PromptBlockFile, - FileB64: bridgesdk.BuildDataURL(actualMimeType, b64Data), - Filename: "document.pdf", - MimeType: actualMimeType, - }) + filename := resolveMediaFileName("document.pdf", "pdf", mediaURL) + appendBlocks = append(appendBlocks, buildTextFileMessage("", false, filename, "application/pdf", content, truncated)) case pendingTypeAudio: - if strings.TrimSpace(result.ResolvedBody) == "" { - blocks = append(blocks, PromptBlock{ - Type: PromptBlockText, - Text: fmt.Sprintf("Audio attachment: %s", mediaURL), - }) - } + return PromptContext{}, fmt.Errorf("audio attachments must be preprocessed into text before prompt assembly") case pendingTypeVideo: - if strings.TrimSpace(result.ResolvedBody) == "" { - blocks = append(blocks, PromptBlock{ - Type: PromptBlockText, - Text: fmt.Sprintf("Video attachment: %s", mediaURL), - }) - } + return PromptContext{}, fmt.Errorf("video attachments must be preprocessed into text before prompt assembly") default: return PromptContext{}, fmt.Errorf("unsupported media type: %s", mediaType) } - result.PromptContext.Messages = append(result.PromptContext.Messages, PromptMessage{ + + promptContext, text, err := oc.buildCurrentTurnText(ctx, portal, meta, caption, eventID, currentTurnTextOptions{ + includeLinkScope: true, + append: appendBlocks, + }) + if err != nil { + return PromptContext{}, err + } + if strings.TrimSpace(text) != "" { + blocks = append([]PromptBlock{{Type: PromptBlockText, Text: text}}, blocks...) + } + promptContext.Messages = append(promptContext.Messages, PromptMessage{ Role: PromptRoleUser, Blocks: blocks, }) - return result.PromptContext, nil + return promptContext, nil } // buildPromptUpToMessage builds a prompt including messages up to and including the specified message @@ -2006,55 +1927,27 @@ func (oc *AIClient) buildContextUpToMessage( targetMessageID networkid.MessageID, newBody string, ) (PromptContext, error) { - var promptContext PromptContext - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, oc.buildSystemMessages(ctx, portal, meta)) - - hr, err := oc.fetchHistoryRows(ctx, portal, meta) + base := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, false), + } + historyMessages, err := oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{ + mode: historyReplayRewrite, + targetMessageID: targetMessageID, + }) if err != nil { return PromptContext{}, err } - if hr != nil { - // Add messages up to the target message, replacing the target with newBody - for i := len(hr.rows) - 1; i >= 0; i-- { - msg := hr.rows[i] - msgMeta := messageMeta(msg) - - // Stop after adding the target message - if msg.ID == targetMessageID { - body := newBody - promptContext.Messages = append(promptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: body, - }}, - }) - break - } - - if !shouldIncludeInHistory(msgMeta) { - continue - } - if hr.resetAt > 0 && msg.Timestamp.UnixMilli() < hr.resetAt { - continue - } - - injectImages := hr.hasVision && i < maxHistoryImageMessages - promptContext.Messages = append(promptContext.Messages, oc.historyMessageBundle(ctx, msgMeta, injectImages)...) - } - } else { - body := strings.TrimSpace(newBody) - body = airuntime.SanitizeChatMessageForDisplay(body, true) - promptContext.Messages = append(promptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: body, - }}, - }) - } - - return promptContext, nil + base.Messages = append(base.Messages, historyMessages...) + body := strings.TrimSpace(newBody) + body = airuntime.SanitizeChatMessageForDisplay(body, true) + base.Messages = append(base.Messages, PromptMessage{ + Role: PromptRoleUser, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: body, + }}, + }) + return base, nil } // downloadAndEncodeMedia downloads media and returns base64-encoded data. @@ -2261,7 +2154,6 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { Timestamp: agentremote.MatrixEventTimestamp(last.Event), } setCanonicalTurnDataFromPromptMessages(userMessage.Metadata.(*MessageMetadata), promptTail(promptContext, 1)) - ensureCanonicalUserMessage(userMessage) // Save user message to database - we must do this ourselves since we already // returned Pending: true to the bridge framework when debouncing started diff --git a/bridges/ai/constructors.go b/bridges/ai/constructors.go index 6f8d3687..2be3a0f4 100644 --- a/bridges/ai/constructors.go +++ b/bridges/ai/constructors.go @@ -59,7 +59,7 @@ func NewAIConnector() *OpenAIConnector { BeeperBridgeType: "ai", DefaultPort: 29345, DefaultCommandPrefix: func() string { - return oc.Config.Bridge.CommandPrefix + return bridgesdk.ResolveCommandPrefix(oc.Config.Bridge.CommandPrefix, "!ai") }, ExampleConfig: exampleNetworkConfig, ConfigData: &oc.Config, diff --git a/bridges/ai/constructors_test.go b/bridges/ai/constructors_test.go index 362c5280..fa7fca51 100644 --- a/bridges/ai/constructors_test.go +++ b/bridges/ai/constructors_test.go @@ -39,6 +39,9 @@ func TestNewAIConnectorUsesSDKConfig(t *testing.T) { if name.DefaultPort != 29345 { t.Fatalf("unexpected default port %d", name.DefaultPort) } + if name.DefaultCommandPrefix != "!ai" { + t.Fatalf("unexpected default command prefix %q", name.DefaultCommandPrefix) + } } func TestNewAIConnectorInitializesClientCacheMap(t *testing.T) { diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 90248b20..04a360fc 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -17,7 +17,6 @@ import ( "github.com/beeper/agentremote" airuntime "github.com/beeper/agentremote/pkg/runtime" "github.com/beeper/agentremote/pkg/shared/stringutil" - bridgesdk "github.com/beeper/agentremote/sdk" ) func messageSendStatusError(err error, message string, reason event.MessageStatusReason) error { @@ -572,11 +571,12 @@ func (oc *AIClient) handleMediaMessage( eventID = msg.Event.ID } - // Check capability (PDF has special OpenRouter handling via file-parser plugin) + // PDFs are normalized into file-context text before prompt assembly, so they + // do not require native provider file support. modelCaps := oc.getModelCapabilitiesForMeta(ctx, meta) supportsMedia := config.capabilityCheck(&modelCaps) - if isPDF && !supportsMedia && oc.isOpenRouterProvider() { - supportsMedia = true // OpenRouter supports PDF via file-parser plugin + if isPDF { + supportsMedia = true } queueSettings, _, _, _ := oc.resolveQueueSettingsForPortal(ctx, portal, meta, "", airuntime.QueueInlineOptions{}) @@ -703,30 +703,6 @@ func (oc *AIClient) handleMediaMessage( } } - // If model lacks audio but agent supports audio understanding, analyze audio first. - if msgType == event.MsgAudio { - audioModel, audioFallback := oc.resolveModelForCapability(ctx, meta, func(caps ModelCapabilities) bool { return caps.SupportsAudio }, oc.resolveAudioUnderstandingModel) - if resp, err := oc.dispatchMediaUnderstandingFallback( - ctx, - audioModel, - audioFallback, - string(mediaURL), - mimeType, - encryptedFile, - caption, - hasUserCaption, - buildMediaUnderstandingPrompt(MediaCapabilityAudio), - oc.analyzeAudioWithModel, - buildMediaUnderstandingMessage("Audio", "Transcript"), - "Audio understanding failed", - "audio understanding produced empty result", - "Couldn't analyze the audio. Try again, or switch to an audio-capable model with !ai model.", - dispatchTextOnly, - ); resp != nil || err != nil { - return resp, err - } - } - return nil, agentremote.UnsupportedMessageStatus(fmt.Errorf( "current model (%s) does not support %s; switch to a capable model using !ai model", oc.effectiveModel(meta), config.capabilityName, @@ -1102,70 +1078,20 @@ func (oc *AIClient) buildContextForRegenerate( latestUserBody string, latestUserID id.EventID, ) (PromptContext, error) { - var promptContext PromptContext - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, oc.buildSystemMessages(ctx, portal, meta)) - - historyLimit := oc.historyLimit(ctx, portal, meta) - resetAt := int64(0) - if meta != nil { - resetAt = meta.SessionResetAt + base := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, false), } - if historyLimit > 0 { - history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(ctx, portal.PortalKey, historyLimit+2) - if err != nil { - return PromptContext{}, fmt.Errorf("failed to load prompt history: %w", err) - } - - // Determine whether to inject images into history (requires vision-capable model). - hasVision := oc.getModelCapabilitiesForMeta(ctx, meta).SupportsVision - historyBundles := make([][]PromptMessage, 0, len(history)) - - // Skip the most recent messages (last user and assistant) and build from older history - skippedUser := false - skippedAssistant := false - includedCount := 0 - for _, msg := range history { - msgMeta := messageMeta(msg) - // Skip commands and non-conversation messages - if !shouldIncludeInHistory(msgMeta) { - continue - } - if resetAt > 0 && msg.Timestamp.UnixMilli() < resetAt { - continue - } - - // Skip the last user message and last assistant message - if !skippedUser && msgMeta.Role == "user" { - skippedUser = true - continue - } - if !skippedAssistant && msgMeta.Role == "assistant" { - skippedAssistant = true - continue - } - - // Only inject images for recent messages and vision-capable models. - // This loop builds newest-to-oldest, so early entries are the most recent. - injectImages := hasVision && includedCount < maxHistoryImageMessages - includedCount++ - bundle := oc.historyMessageBundle(ctx, msgMeta, injectImages) - if len(bundle) > 0 { - historyBundles = append(historyBundles, bundle) - } - } - - for i := len(historyBundles) - 1; i >= 0; i-- { - promptContext.Messages = append(promptContext.Messages, historyBundles[i]...) - } + historyMessages, err := oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{mode: historyReplayRegen}) + if err != nil { + return PromptContext{}, err } - - latest := latestUserBody - promptContext.Messages = append(promptContext.Messages, PromptMessage{ + base.Messages = append(base.Messages, historyMessages...) + base.Messages = append(base.Messages, PromptMessage{ Role: PromptRoleUser, Blocks: []PromptBlock{{ Type: PromptBlockText, - Text: latest, + Text: latestUserBody, }}, }) - return promptContext, nil + return base, nil } diff --git a/bridges/ai/heartbeat_execute.go b/bridges/ai/heartbeat_execute.go index d1297bec..27d9a57c 100644 --- a/bridges/ai/heartbeat_execute.go +++ b/bridges/ai/heartbeat_execute.go @@ -283,7 +283,7 @@ func systemEventsOwnerKey(oc *AIClient) string { } func (oc *AIClient) buildContextWithHeartbeat(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, prompt string) (PromptContext, error) { - base, err := oc.buildBaseContext(ctx, portal, meta) + base, text, err := oc.buildCurrentTurnText(ctx, portal, meta, prompt, "", currentTurnTextOptions{}) if err != nil { return PromptContext{}, err } @@ -291,7 +291,7 @@ func (oc *AIClient) buildContextWithHeartbeat(ctx context.Context, portal *bridg Role: PromptRoleUser, Blocks: []PromptBlock{{ Type: PromptBlockText, - Text: prompt, + Text: text, }}, }) return base, nil diff --git a/bridges/ai/identifiers.go b/bridges/ai/identifiers.go index 53abbfe6..f28b506e 100644 --- a/bridges/ai/identifiers.go +++ b/bridges/ai/identifiers.go @@ -200,11 +200,7 @@ func shouldIncludeInHistory(meta *MessageMetadata) bool { if meta.Role != "user" && meta.Role != "assistant" { return false } - return len(meta.CanonicalTurnData) > 0 || - strings.TrimSpace(meta.Body) != "" || - len(meta.ToolCalls) > 0 || - strings.TrimSpace(meta.MediaURL) != "" || - len(meta.GeneratedFiles) > 0 + return len(meta.CanonicalTurnData) > 0 } func loginMetadata(login *bridgev2.UserLogin) *UserLoginMetadata { diff --git a/bridges/ai/image_understanding.go b/bridges/ai/image_understanding.go index 49dd9bf2..83fa17f3 100644 --- a/bridges/ai/image_understanding.go +++ b/bridges/ai/image_understanding.go @@ -7,8 +7,6 @@ import ( "strings" "maunium.net/go/mautrix/event" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func (oc *AIClient) canUseMediaUnderstanding(meta *PortalMetadata) bool { @@ -147,17 +145,6 @@ func (oc *AIClient) resolveVisionModelForImage(ctx context.Context, meta *Portal ) } -// resolveAudioUnderstandingModel returns an audio-capable model from the agent's model chain. -func (oc *AIClient) resolveAudioUnderstandingModel(ctx context.Context, meta *PortalMetadata) string { - return oc.resolveUnderstandingModel( - ctx, - meta, - func(caps ModelCapabilities) bool { return caps.SupportsAudio }, - func(info ModelInfo) bool { return info.SupportsAudio }, - "audio", - ) -} - func (oc *AIClient) pickModelFromCache(cache *ModelCache, provider string, supports modelInfoFilter) string { if cache == nil || len(cache.Models) == 0 { return "" @@ -224,9 +211,9 @@ func (oc *AIClient) analyzeImageWithModel( actualMimeType = "image/jpeg" } - dataURL := bridgesdk.BuildDataURL(actualMimeType, b64Data) + dataURL := BuildDataURL(actualMimeType, b64Data) - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( + ctxPrompt := UserPromptContext( PromptBlock{ Type: PromptBlockImage, ImageURL: dataURL, @@ -236,7 +223,7 @@ func (oc *AIClient) analyzeImageWithModel( Type: PromptBlockText, Text: prompt, }, - )} + ) resp, err := oc.provider.Generate(ctx, GenerateParams{ Model: modelIDForAPI, @@ -249,67 +236,6 @@ func (oc *AIClient) analyzeImageWithModel( return strings.TrimSpace(resp.Content), nil } - -func (oc *AIClient) analyzeAudioWithModel( - ctx context.Context, - modelID string, - audioURL string, - mimeType string, - encryptedFile *event.EncryptedFileInfo, - prompt string, -) (string, error) { - if strings.TrimSpace(modelID) == "" { - return "", errors.New("missing model for audio analysis") - } - if strings.TrimSpace(prompt) == "" { - prompt = defaultPromptByCapability[MediaCapabilityAudio] - } - - modelIDForAPI := oc.modelIDForAPI(modelID) - audioRef := mediaSourceLabel(audioURL, encryptedFile) - b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, audioURL, encryptedFile, 25, mimeType) - if err != nil { - return "", fmt.Errorf("failed to download audio %s for model %s: %w", audioRef, modelIDForAPI, err) - } - actualMimeType = strings.TrimSpace(actualMimeType) - if actualMimeType == "" { - actualMimeType = strings.TrimSpace(mimeType) - } - format := getAudioFormat(actualMimeType) - if format == "" { - format = "mp3" - } - - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( - PromptBlock{ - Type: PromptBlockAudio, - AudioB64: b64Data, - AudioFormat: format, - }, - PromptBlock{ - Type: PromptBlockText, - Text: prompt, - }, - )} - - params := GenerateParams{ - Model: modelIDForAPI, - Context: ctxPrompt, - MaxCompletionTokens: defaultImageUnderstandingLimit, - } - var resp *GenerateResponse - if provider, ok := oc.provider.(*OpenAIProvider); ok { - resp, err = provider.generateChatCompletions(ctx, params) - } else { - resp, err = oc.provider.Generate(ctx, params) - } - if err != nil { - return "", fmt.Errorf("audio analysis failed for model %s (audio %s): %w", modelIDForAPI, audioRef, err) - } - - return strings.TrimSpace(resp.Content), nil -} - func mediaSourceLabel(mediaURL string, encryptedFile *event.EncryptedFileInfo) string { source := strings.TrimSpace(mediaURL) if encryptedFile != nil && encryptedFile.URL != "" { diff --git a/bridges/ai/internal_dispatch.go b/bridges/ai/internal_dispatch.go index 21fddb9e..bfc3913d 100644 --- a/bridges/ai/internal_dispatch.go +++ b/bridges/ai/internal_dispatch.go @@ -60,7 +60,6 @@ func (oc *AIClient) dispatchInternalMessage( Timestamp: time.Now(), } setCanonicalTurnDataFromPromptMessages(userMessage.Metadata.(*MessageMetadata), promptTail(promptContext, 1)) - ensureCanonicalUserMessage(userMessage) if _, err := oc.UserLogin.Bridge.GetGhostByID(ctx, userMessage.SenderID); err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure user ghost before saving internal message") } diff --git a/bridges/ai/media_understanding_runner.go b/bridges/ai/media_understanding_runner.go index 1b468891..49962119 100644 --- a/bridges/ai/media_understanding_runner.go +++ b/bridges/ai/media_understanding_runner.go @@ -17,7 +17,6 @@ import ( "maunium.net/go/mautrix/event" "github.com/beeper/agentremote/pkg/shared/stringutil" - bridgesdk "github.com/beeper/agentremote/sdk" ) type mediaUnderstandingResult struct { @@ -705,9 +704,9 @@ func (oc *AIClient) describeImageWithEntry( actualMime = "image/jpeg" } b64Data := base64.StdEncoding.EncodeToString(rawData) - dataURL := bridgesdk.BuildDataURL(actualMime, b64Data) + dataURL := BuildDataURL(actualMime, b64Data) - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( + ctxPrompt := UserPromptContext( PromptBlock{ Type: PromptBlockText, Text: prompt, @@ -717,7 +716,7 @@ func (oc *AIClient) describeImageWithEntry( ImageURL: dataURL, MimeType: actualMime, }, - )} + ) modelIDForAPI := oc.modelIDForAPI(ResolveAlias(modelID)) var resp *GenerateResponse if entryProvider == "openrouter" { @@ -849,34 +848,6 @@ func (oc *AIClient) describeVideoWithEntry( return nil, errors.New("video payload exceeds base64 limit") } - if providerID == "openrouter" { - modelID := strings.TrimSpace(entry.Model) - if modelID == "" { - return nil, errors.New("video understanding requires model id") - } - videoB64 := base64.StdEncoding.EncodeToString(data) - - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( - PromptBlock{ - Type: PromptBlockText, - Text: prompt, - }, - PromptBlock{ - Type: PromptBlockVideo, - VideoB64: videoB64, - MimeType: actualMime, - }, - )} - modelIDForAPI := oc.modelIDForAPI(ResolveAlias(modelID)) - var resp *GenerateResponse - resp, err = oc.generateWithOpenRouter(ctx, modelIDForAPI, ctxPrompt, capCfg, entry) - if err != nil { - return nil, err - } - text := strings.TrimSpace(resp.Content) - text = truncateText(text, maxChars) - return buildMediaOutput(MediaCapabilityVideo, text, entry.Provider, modelID, attachmentIndex), nil - } if providerID != "google" { return nil, fmt.Errorf("unsupported video provider: %s", providerID) } @@ -926,9 +897,6 @@ func (oc *AIClient) generateWithOpenRouter( Context: promptContext, MaxCompletionTokens: defaultImageUnderstandingLimit, } - if bridgesdk.PromptContextHasBlockType(promptContext.PromptContext, PromptBlockAudio, PromptBlockVideo) { - return provider.generateChatCompletions(ctx, params) - } return provider.Generate(ctx, params) } diff --git a/bridges/ai/messages.go b/bridges/ai/messages.go index 0504b278..1efd3ce9 100644 --- a/bridges/ai/messages.go +++ b/bridges/ai/messages.go @@ -1,32 +1,62 @@ package ai -import bridgesdk "github.com/beeper/agentremote/sdk" - -type PromptRole = bridgesdk.PromptRole +type PromptRole string const ( - PromptRoleUser PromptRole = bridgesdk.PromptRoleUser - PromptRoleAssistant PromptRole = bridgesdk.PromptRoleAssistant - PromptRoleToolResult PromptRole = bridgesdk.PromptRoleToolResult + PromptRoleUser PromptRole = "user" + PromptRoleAssistant PromptRole = "assistant" + PromptRoleToolResult PromptRole = "tool_result" ) -type PromptBlockType = bridgesdk.PromptBlockType +type PromptBlockType string const ( - PromptBlockText PromptBlockType = bridgesdk.PromptBlockText - PromptBlockImage PromptBlockType = bridgesdk.PromptBlockImage - PromptBlockFile PromptBlockType = bridgesdk.PromptBlockFile - PromptBlockThinking PromptBlockType = bridgesdk.PromptBlockThinking - PromptBlockToolCall PromptBlockType = bridgesdk.PromptBlockToolCall - PromptBlockAudio PromptBlockType = bridgesdk.PromptBlockAudio - PromptBlockVideo PromptBlockType = bridgesdk.PromptBlockVideo + PromptBlockText PromptBlockType = "text" + PromptBlockImage PromptBlockType = "image" + PromptBlockThinking PromptBlockType = "thinking" + PromptBlockToolCall PromptBlockType = "tool_call" ) -type PromptBlock = bridgesdk.PromptBlock -type PromptMessage = bridgesdk.PromptMessage +type PromptBlock struct { + Type PromptBlockType + + Text string + + ImageURL string + ImageB64 string + MimeType string + + ToolCallID string + ToolName string + ToolCallArguments string +} + +type PromptMessage struct { + Role PromptRole + Blocks []PromptBlock + ToolCallID string + ToolName string + IsError bool +} + +func (m PromptMessage) Text() string { + var text string + for _, block := range m.Blocks { + switch block.Type { + case PromptBlockText, PromptBlockThinking: + if text == "" { + text = block.Text + } else if block.Text != "" { + text += "\n" + block.Text + } + } + } + return text +} -// PromptContext extends the shared provider-facing prompt model with bridge-local tool definitions. +// PromptContext is the bridge-local prompt envelope used throughout bridges/ai. type PromptContext struct { - bridgesdk.PromptContext - Tools []ToolDefinition + SystemPrompt string + Messages []PromptMessage + Tools []ToolDefinition } diff --git a/bridges/ai/messages_responses_input_test.go b/bridges/ai/messages_responses_input_test.go index d414b784..9a3a09ee 100644 --- a/bridges/ai/messages_responses_input_test.go +++ b/bridges/ai/messages_responses_input_test.go @@ -4,15 +4,12 @@ import ( "testing" "github.com/openai/openai-go/v3/responses" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { - input := bridgesdk.PromptContextToResponsesInput(bridgesdk.UserPromptContext( + input := PromptContextToResponsesInput(UserPromptContext( PromptBlock{Type: PromptBlockText, Text: "hello"}, PromptBlock{Type: PromptBlockImage, ImageB64: "aGVsbG8=", MimeType: "image/png"}, - PromptBlock{Type: PromptBlockFile, FileB64: "cGRm", Filename: "document.pdf"}, )) if len(input) != 1 { t.Fatalf("expected 1 input item, got %d", len(input)) @@ -33,7 +30,6 @@ func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { foundText := false foundImage := false - foundFile := false for _, part := range parts { if part.OfInputText != nil { foundText = true @@ -47,18 +43,9 @@ func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { t.Fatalf("expected image part data URL to preserve content, got %#v", part.OfInputImage.ImageURL.Value) } } - if part.OfInputFile != nil { - foundFile = true - if part.OfInputFile.Filename.Value != "document.pdf" { - t.Fatalf("expected file part filename document.pdf, got %#v", part.OfInputFile.Filename.Value) - } - if part.OfInputFile.FileData.Value != "cGRm" { - t.Fatalf("expected file part data to preserve content, got %#v", part.OfInputFile.FileData.Value) - } - } } - if !foundText || !foundImage || !foundFile { - t.Fatalf("expected text, image, and file parts (got text=%v image=%v file=%v)", foundText, foundImage, foundFile) + if !foundText || !foundImage { + t.Fatalf("expected text and image parts (got text=%v image=%v)", foundText, foundImage) } } diff --git a/bridges/ai/prompt_builder.go b/bridges/ai/prompt_builder.go new file mode 100644 index 00000000..0c704d70 --- /dev/null +++ b/bridges/ai/prompt_builder.go @@ -0,0 +1,178 @@ +package ai + +import ( + "context" + "strings" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type historyReplayMode string + +const ( + historyReplayNormal historyReplayMode = "normal" + historyReplayRegen historyReplayMode = "regenerate" + historyReplayRewrite historyReplayMode = "rewrite" +) + +type historyReplayOptions struct { + mode historyReplayMode + targetMessageID networkid.MessageID +} + +type currentTurnTextOptions struct { + rawEventContent map[string]any + includeLinkScope bool + prepend []string + append []string +} + +func joinPromptFragments(parts ...string) string { + var filtered []string + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + filtered = append(filtered, part) + } + } + return strings.TrimSpace(strings.Join(filtered, "\n\n")) +} + +func (oc *AIClient) fetchHistoryRowsWithExtra( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + extra int, +) (*historyLoadResult, error) { + historyLimit := oc.historyLimit(ctx, portal, meta) + if historyLimit <= 0 { + return nil, nil + } + if extra > 0 { + historyLimit += extra + } + resetAt := int64(0) + if meta != nil { + resetAt = meta.SessionResetAt + } + history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(ctx, portal.PortalKey, historyLimit) + if err != nil { + return nil, err + } + return &historyLoadResult{ + rows: history, + hasVision: oc.getModelCapabilitiesForMeta(ctx, meta).SupportsVision, + resetAt: resetAt, + }, nil +} + +func (oc *AIClient) replayHistoryMessages( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + opts historyReplayOptions, +) ([]PromptMessage, error) { + extra := 0 + if opts.mode == historyReplayRegen { + extra = 2 + } + hr, err := oc.fetchHistoryRowsWithExtra(ctx, portal, meta, extra) + if err != nil { + return nil, err + } + if hr == nil { + return nil, nil + } + + type replayCandidate struct { + row *database.Message + meta *MessageMetadata + } + + candidates := make([]replayCandidate, 0, len(hr.rows)) + for _, row := range hr.rows { + msgMeta := messageMeta(row) + if opts.mode == historyReplayRewrite && row.ID == opts.targetMessageID { + candidates = append(candidates, replayCandidate{row: row, meta: msgMeta}) + continue + } + if !shouldIncludeInHistory(msgMeta) { + continue + } + if hr.resetAt > 0 && row.Timestamp.UnixMilli() < hr.resetAt { + continue + } + candidates = append(candidates, replayCandidate{row: row, meta: msgMeta}) + } + + skipUserID := networkid.MessageID("") + skipAssistantID := networkid.MessageID("") + if opts.mode == historyReplayRegen { + for _, candidate := range candidates { + if skipUserID == "" && candidate.meta != nil && candidate.meta.Role == "user" { + skipUserID = candidate.row.ID + continue + } + if skipAssistantID == "" && candidate.meta != nil && candidate.meta.Role == "assistant" { + skipAssistantID = candidate.row.ID + } + if skipUserID != "" && skipAssistantID != "" { + break + } + } + } + + var messages []PromptMessage + for i := len(candidates) - 1; i >= 0; i-- { + candidate := candidates[i] + if opts.mode == historyReplayRewrite && candidate.row.ID == opts.targetMessageID { + break + } + if candidate.row.ID == skipUserID || candidate.row.ID == skipAssistantID { + continue + } + injectImages := hr.hasVision && i < maxHistoryImageMessages + messages = append(messages, oc.historyMessageBundle(ctx, candidate.meta, injectImages)...) + } + return messages, nil +} + +func (oc *AIClient) buildCurrentTurnText( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + userText string, + eventID id.EventID, + opts currentTurnTextOptions, +) (PromptContext, string, error) { + result, err := oc.prepareInboundPromptContext(ctx, portal, meta, userText, eventID) + if err != nil { + return PromptContext{}, "", err + } + + prepend := append([]string{}, opts.prepend...) + if portal != nil && portal.MXID != "" { + reactionFeedback := DrainReactionFeedback(portal.MXID) + if len(reactionFeedback) > 0 { + if feedbackText := FormatReactionFeedback(reactionFeedback); feedbackText != "" { + prepend = append(prepend, feedbackText) + } + } + } + if result.UntrustedPrefix != "" { + prepend = append(prepend, result.UntrustedPrefix) + } + + appendParts := append([]string{}, opts.append...) + if opts.includeLinkScope { + if linkContext := oc.buildLinkContext(ctx, userText, opts.rawEventContent); linkContext != "" { + appendParts = append(appendParts, linkContext) + } + } + + body := joinPromptFragments(append(append(prepend, result.ResolvedBody), appendParts...)...) + return result.PromptContext, body, nil +} diff --git a/bridges/ai/prompt_context_local.go b/bridges/ai/prompt_context_local.go new file mode 100644 index 00000000..fccf8944 --- /dev/null +++ b/bridges/ai/prompt_context_local.go @@ -0,0 +1,371 @@ +package ai + +import ( + "fmt" + "slices" + "strings" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" +) + +func UserPromptContext(blocks ...PromptBlock) PromptContext { + return PromptContext{ + Messages: []PromptMessage{{ + Role: PromptRoleUser, + Blocks: slices.Clone(blocks), + }}, + } +} + +func AppendPromptText(dst *string, text string) { + text = strings.TrimSpace(text) + if text == "" { + return + } + if *dst == "" { + *dst = text + return + } + *dst = strings.TrimSpace(*dst + "\n\n" + text) +} + +func BuildDataURL(mimeType, b64Data string) string { + return fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data) +} + +func resolveBlockImageURL(block PromptBlock) string { + imageURL := strings.TrimSpace(block.ImageURL) + if imageURL == "" && block.ImageB64 != "" { + mimeType := strings.TrimSpace(block.MimeType) + if mimeType == "" { + mimeType = "image/jpeg" + } + imageURL = BuildDataURL(mimeType, block.ImageB64) + } + return imageURL +} + +func PromptContextToResponsesInput(ctx PromptContext) responses.ResponseInputParam { + var result responses.ResponseInputParam + for _, msg := range ctx.Messages { + result = append(result, promptMessageToResponsesInputs(msg)...) + } + return result +} + +func promptMessageToResponsesInputs(msg PromptMessage) responses.ResponseInputParam { + switch msg.Role { + case PromptRoleUser: + content := make([]responses.ResponseInputContentUnionParam, 0, len(msg.Blocks)) + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText: + text := strings.TrimSpace(block.Text) + if text == "" { + continue + } + content = append(content, responses.ResponseInputContentUnionParam{ + OfInputText: &responses.ResponseInputTextParam{Text: text}, + }) + case PromptBlockImage: + imageURL := resolveBlockImageURL(block) + if imageURL == "" { + continue + } + content = append(content, responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: param.NewOpt(imageURL), + }, + }) + } + } + if len(content) == 0 { + return nil + } + return responses.ResponseInputParam{{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfInputItemContentList: content}, + }, + }} + case PromptRoleAssistant: + var result responses.ResponseInputParam + text := strings.TrimSpace(msg.Text()) + if text != "" { + result = append(result, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.String(text)}, + }, + }) + } + for _, block := range msg.Blocks { + if block.Type != PromptBlockToolCall || strings.TrimSpace(block.ToolCallID) == "" || strings.TrimSpace(block.ToolName) == "" { + continue + } + args := strings.TrimSpace(block.ToolCallArguments) + if args == "" { + args = "{}" + } + result = append(result, responses.ResponseInputItemParamOfFunctionCall(args, block.ToolCallID, block.ToolName)) + } + return result + case PromptRoleToolResult: + text := strings.TrimSpace(msg.Text()) + if strings.TrimSpace(msg.ToolCallID) == "" || text == "" { + return nil + } + return responses.ResponseInputParam{buildFunctionCallOutputItem(msg.ToolCallID, text, false)} + default: + return nil + } +} + +func PromptContextToChatCompletionMessages(ctx PromptContext, supportsVideoURL bool) []openai.ChatCompletionMessageParamUnion { + var messages []openai.ChatCompletionMessageParamUnion + if strings.TrimSpace(ctx.SystemPrompt) != "" { + messages = append(messages, openai.SystemMessage(strings.TrimSpace(ctx.SystemPrompt))) + } + for _, msg := range ctx.Messages { + switch msg.Role { + case PromptRoleUser: + user := promptUserToChatMessage(msg) + if user != nil { + messages = append(messages, openai.ChatCompletionMessageParamUnion{OfUser: user}) + } + case PromptRoleAssistant: + assistant := promptAssistantToChatMessage(msg) + if assistant != nil { + messages = append(messages, openai.ChatCompletionMessageParamUnion{OfAssistant: assistant}) + } + case PromptRoleToolResult: + tool := promptToolToChatMessage(msg) + if tool != nil { + messages = append(messages, openai.ChatCompletionMessageParamUnion{OfTool: tool}) + } + } + } + return messages +} + +func promptUserToChatMessage(msg PromptMessage) *openai.ChatCompletionUserMessageParam { + var contentParts []openai.ChatCompletionContentPartUnionParam + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText: + text := strings.TrimSpace(block.Text) + if text == "" { + continue + } + contentParts = append(contentParts, openai.ChatCompletionContentPartUnionParam{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: text, + }, + }) + case PromptBlockImage: + imageURL := resolveBlockImageURL(block) + if imageURL == "" { + continue + } + contentParts = append(contentParts, openai.ChatCompletionContentPartUnionParam{ + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: imageURL, + }, + }, + }) + } + } + if len(contentParts) == 0 { + return nil + } + return &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{OfArrayOfContentParts: contentParts}, + } +} + +func promptAssistantToChatMessage(msg PromptMessage) *openai.ChatCompletionAssistantMessageParam { + var contentParts []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion + var toolCalls []openai.ChatCompletionMessageToolCallUnionParam + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText, PromptBlockThinking: + text := strings.TrimSpace(block.Text) + if text == "" { + continue + } + contentParts = append(contentParts, openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: text, + }, + }) + case PromptBlockToolCall: + if strings.TrimSpace(block.ToolCallID) == "" || strings.TrimSpace(block.ToolName) == "" { + continue + } + toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallUnionParam{ + OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ + ID: block.ToolCallID, + Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ + Name: block.ToolName, + Arguments: block.ToolCallArguments, + }, + }, + }) + } + } + if len(contentParts) == 0 && len(toolCalls) == 0 { + return nil + } + return &openai.ChatCompletionAssistantMessageParam{ + Content: openai.ChatCompletionAssistantMessageParamContentUnion{OfArrayOfContentParts: contentParts}, + ToolCalls: toolCalls, + } +} + +func promptToolToChatMessage(msg PromptMessage) *openai.ChatCompletionToolMessageParam { + text := strings.TrimSpace(msg.Text()) + if strings.TrimSpace(msg.ToolCallID) == "" || text == "" { + return nil + } + return &openai.ChatCompletionToolMessageParam{ + ToolCallID: msg.ToolCallID, + Content: openai.ChatCompletionToolMessageParamContentUnion{ + OfString: openai.String(text), + }, + } +} + +func ChatMessagesToPromptContext(messages []openai.ChatCompletionMessageParamUnion) PromptContext { + var ctx PromptContext + for _, msg := range messages { + appendChatMessageToPromptContext(&ctx, msg) + } + return ctx +} + +func appendChatMessageToPromptContext(ctx *PromptContext, msg openai.ChatCompletionMessageParamUnion) { + if ctx == nil { + return + } + switch { + case msg.OfSystem != nil: + AppendPromptText(&ctx.SystemPrompt, extractChatSystemText(msg.OfSystem.Content)) + case msg.OfUser != nil: + ctx.Messages = append(ctx.Messages, promptMessageFromChatUser(msg.OfUser)) + case msg.OfAssistant != nil: + ctx.Messages = append(ctx.Messages, promptMessageFromChatAssistant(msg.OfAssistant)) + case msg.OfTool != nil: + ctx.Messages = append(ctx.Messages, promptMessageFromChatTool(msg.OfTool)) + } +} + +func extractChatSystemText(content openai.ChatCompletionSystemMessageParamContentUnion) string { + if content.OfString.Value != "" { + return content.OfString.Value + } + var values []string + for _, part := range content.OfArrayOfContentParts { + if text := strings.TrimSpace(part.Text); text != "" { + values = append(values, text) + } + } + return strings.Join(values, "\n") +} + +func promptMessageFromChatUser(msg *openai.ChatCompletionUserMessageParam) PromptMessage { + pm := PromptMessage{Role: PromptRoleUser} + if msg == nil { + return pm + } + if msg.Content.OfString.Value != "" { + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: msg.Content.OfString.Value}) + } + for _, part := range msg.Content.OfArrayOfContentParts { + switch { + case part.OfText != nil: + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: part.OfText.Text}) + case part.OfImageURL != nil: + pm.Blocks = append(pm.Blocks, PromptBlock{ + Type: PromptBlockImage, + ImageURL: part.OfImageURL.ImageURL.URL, + MimeType: inferPromptMimeTypeFromDataURL(part.OfImageURL.ImageURL.URL), + }) + } + } + return pm +} + +func promptMessageFromChatAssistant(msg *openai.ChatCompletionAssistantMessageParam) PromptMessage { + pm := PromptMessage{Role: PromptRoleAssistant} + if msg == nil { + return pm + } + if msg.Content.OfString.Value != "" { + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: msg.Content.OfString.Value}) + } + for _, part := range msg.Content.OfArrayOfContentParts { + if part.OfText == nil { + continue + } + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: part.OfText.Text}) + } + for _, toolCall := range msg.ToolCalls { + if toolCall.OfFunction == nil { + continue + } + pm.Blocks = append(pm.Blocks, PromptBlock{ + Type: PromptBlockToolCall, + ToolCallID: toolCall.OfFunction.ID, + ToolName: toolCall.OfFunction.Function.Name, + ToolCallArguments: toolCall.OfFunction.Function.Arguments, + }) + } + return pm +} + +func promptMessageFromChatTool(msg *openai.ChatCompletionToolMessageParam) PromptMessage { + pm := PromptMessage{Role: PromptRoleToolResult} + if msg == nil { + return pm + } + pm.ToolCallID = msg.ToolCallID + if msg.Content.OfString.Value != "" { + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: msg.Content.OfString.Value}) + } + for _, part := range msg.Content.OfArrayOfContentParts { + if strings.TrimSpace(part.Text) == "" { + continue + } + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) + } + return pm +} + +func inferPromptMimeTypeFromDataURL(value string) string { + value = strings.TrimSpace(value) + rest, ok := strings.CutPrefix(value, "data:") + if !ok { + return "" + } + idx := strings.Index(rest, ";") + if idx <= 0 { + return "" + } + return rest[:idx] +} + +func HasUnsupportedResponsesPromptContext(ctx PromptContext) bool { + for _, msg := range ctx.Messages { + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText, PromptBlockImage, PromptBlockThinking, PromptBlockToolCall: + default: + return true + } + } + } + return false +} diff --git a/bridges/ai/prompt_projection_local.go b/bridges/ai/prompt_projection_local.go new file mode 100644 index 00000000..a31592e4 --- /dev/null +++ b/bridges/ai/prompt_projection_local.go @@ -0,0 +1,166 @@ +package ai + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/beeper/agentremote/sdk" +) + +func promptMessagesFromTurnData(td sdk.TurnData) []PromptMessage { + if td.Role == "" { + return nil + } + switch td.Role { + case "user": + msg := PromptMessage{Role: PromptRoleUser} + for _, part := range td.Parts { + switch normalizePromptTurnPartType(part.Type) { + case "text": + if strings.TrimSpace(part.Text) != "" { + msg.Blocks = append(msg.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) + } + case "image": + imageB64 := promptExtraString(part.Extra, "imageB64") + if strings.TrimSpace(part.URL) == "" && imageB64 == "" { + continue + } + msg.Blocks = append(msg.Blocks, PromptBlock{ + Type: PromptBlockImage, + ImageURL: part.URL, + ImageB64: imageB64, + MimeType: part.MediaType, + }) + } + } + if len(msg.Blocks) == 0 { + return nil + } + return []PromptMessage{msg} + case "assistant": + assistant := PromptMessage{Role: PromptRoleAssistant} + var results []PromptMessage + for _, part := range td.Parts { + switch normalizePromptTurnPartType(part.Type) { + case "text": + if strings.TrimSpace(part.Text) != "" { + assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) + } + case "reasoning": + text := strings.TrimSpace(part.Reasoning) + if text == "" { + text = strings.TrimSpace(part.Text) + } + if text != "" { + assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockThinking, Text: text}) + } + case "tool": + if strings.TrimSpace(part.ToolCallID) != "" && strings.TrimSpace(part.ToolName) != "" { + assistant.Blocks = append(assistant.Blocks, PromptBlock{ + Type: PromptBlockToolCall, + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + ToolCallArguments: canonicalPromptToolArguments(part.Input), + }) + } + outputText := strings.TrimSpace(formatPromptCanonicalValue(part.Output)) + if outputText == "" { + outputText = strings.TrimSpace(part.ErrorText) + } + if outputText == "" && part.State == "output-denied" { + outputText = "Denied by user" + } + if strings.TrimSpace(part.ToolCallID) != "" && outputText != "" { + results = append(results, PromptMessage{ + Role: PromptRoleToolResult, + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + IsError: strings.TrimSpace(part.ErrorText) != "", + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: outputText, + }}, + }) + } + } + } + if len(assistant.Blocks) == 0 && len(results) == 0 { + return nil + } + out := make([]PromptMessage, 0, 1+len(results)) + if len(assistant.Blocks) > 0 { + out = append(out, assistant) + } + return append(out, results...) + default: + return nil + } +} + +func turnDataFromUserPromptMessages(messages []PromptMessage) (sdk.TurnData, bool) { + if len(messages) == 0 { + return sdk.TurnData{}, false + } + msg := messages[0] + if msg.Role != PromptRoleUser { + return sdk.TurnData{}, false + } + td := sdk.TurnData{Role: "user"} + td.Parts = make([]sdk.TurnPart, 0, len(msg.Blocks)) + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText: + if strings.TrimSpace(block.Text) != "" { + td.Parts = append(td.Parts, sdk.TurnPart{Type: "text", Text: block.Text}) + } + case PromptBlockImage: + if strings.TrimSpace(block.ImageURL) == "" && strings.TrimSpace(block.ImageB64) == "" { + continue + } + part := sdk.TurnPart{Type: "image", URL: block.ImageURL, MediaType: block.MimeType} + if strings.TrimSpace(block.ImageB64) != "" { + part.Extra = map[string]any{"imageB64": block.ImageB64} + } + td.Parts = append(td.Parts, part) + } + } + return td, len(td.Parts) > 0 +} + +func promptExtraString(extra map[string]any, key string) string { + if len(extra) == 0 { + return "" + } + value, _ := extra[key].(string) + return value +} + +func normalizePromptTurnPartType(partType string) string { + if partType == "dynamic-tool" { + return "tool" + } + return partType +} + +func canonicalPromptToolArguments(raw any) string { + if value := strings.TrimSpace(formatPromptCanonicalValue(raw)); value != "" { + return value + } + return "{}" +} + +func formatPromptCanonicalValue(raw any) string { + switch typed := raw.(type) { + case nil: + return "" + case string: + return typed + default: + data, err := json.Marshal(typed) + if err != nil { + return fmt.Sprint(typed) + } + return string(data) + } +} diff --git a/bridges/ai/provider_openai_chat.go b/bridges/ai/provider_openai_chat.go index f1e7be01..f16f0e1a 100644 --- a/bridges/ai/provider_openai_chat.go +++ b/bridges/ai/provider_openai_chat.go @@ -6,12 +6,10 @@ import ( "fmt" "github.com/openai/openai-go/v3" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func (o *OpenAIProvider) generateChatCompletions(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - chatMessages := bridgesdk.PromptContextToChatCompletionMessages(params.Context.PromptContext, isOpenRouterBaseURL(o.baseURL)) + chatMessages := PromptContextToChatCompletionMessages(params.Context, isOpenRouterBaseURL(o.baseURL)) if len(chatMessages) == 0 { return nil, errors.New("no chat messages for completion") } diff --git a/bridges/ai/provider_openai_responses.go b/bridges/ai/provider_openai_responses.go index 1e18e08e..d934c5a0 100644 --- a/bridges/ai/provider_openai_responses.go +++ b/bridges/ai/provider_openai_responses.go @@ -8,8 +8,6 @@ import ( "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" - - bridgesdk "github.com/beeper/agentremote/sdk" ) // reasoningEffortMap maps string effort levels to SDK constants. @@ -24,7 +22,7 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R responsesParams := responses.ResponseNewParams{ Model: params.Model, Input: responses.ResponseNewParamsInputUnion{ - OfInputItemList: bridgesdk.PromptContextToResponsesInput(params.Context.PromptContext), + OfInputItemList: PromptContextToResponsesInput(params.Context), }, } @@ -60,7 +58,7 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R // GenerateStream generates a streaming response from OpenAI using the Responses API. func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GenerateParams) (<-chan StreamEvent, error) { - if bridgesdk.HasUnsupportedResponsesPromptContext(params.Context.PromptContext) { + if HasUnsupportedResponsesPromptContext(params.Context) { return nil, fmt.Errorf("responses API does not support prompt context block types required by this request") } @@ -150,7 +148,7 @@ func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GeneratePara // Generate performs a non-streaming generation using the Responses API. func (o *OpenAIProvider) Generate(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - if bridgesdk.HasUnsupportedResponsesPromptContext(params.Context.PromptContext) { + if HasUnsupportedResponsesPromptContext(params.Context) { return nil, fmt.Errorf("responses API does not support prompt context block types required by this request") } diff --git a/bridges/ai/provider_openai_responses_test.go b/bridges/ai/provider_openai_responses_test.go index 70e02059..9db3e9c7 100644 --- a/bridges/ai/provider_openai_responses_test.go +++ b/bridges/ai/provider_openai_responses_test.go @@ -1,37 +1,19 @@ package ai import ( - "context" - "strings" "testing" "go.mau.fi/util/ptr" - - bridgesdk "github.com/beeper/agentremote/sdk" ) -func TestGenerateStreamRejectsUnsupportedResponsesPromptContext(t *testing.T) { +func TestBuildResponsesParamsAcceptsBridgePromptContext(t *testing.T) { provider := &OpenAIProvider{} - params := GenerateParams{ - Context: PromptContext{ - PromptContext: bridgesdk.UserPromptContext(bridgesdk.PromptBlock{ - Type: bridgesdk.PromptBlockAudio, - AudioB64: "YXVkaW8=", - AudioFormat: "mp3", - MimeType: "audio/mpeg", - }), - }, - } - - events, err := provider.GenerateStream(context.Background(), params) - if err == nil { - t.Fatal("expected unsupported prompt context error") - } - if events != nil { - t.Fatal("expected nil event channel on validation failure") - } - if !strings.Contains(err.Error(), "responses API does not support prompt context block types required by this request") { - t.Fatalf("unexpected error: %v", err) + params := provider.buildResponsesParams(GenerateParams{ + Model: "gpt-5.2", + Context: UserPromptContext(PromptBlock{Type: PromptBlockText, Text: "hello"}), + }) + if len(params.Input.OfInputItemList) != 1 { + t.Fatalf("expected one input item, got %d", len(params.Input.OfInputItemList)) } } diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index 5731318f..e7531a62 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -14,7 +14,6 @@ import ( integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" airuntime "github.com/beeper/agentremote/pkg/runtime" - bridgesdk "github.com/beeper/agentremote/sdk" ) const ( @@ -370,7 +369,7 @@ func (oc *AIClient) runAgentLoopWithRetry( } func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext PromptContext) (responseFunc, string) { - if bridgesdk.HasUnsupportedResponsesPromptContext(promptContext.PromptContext) { + if HasUnsupportedResponsesPromptContext(promptContext) { return oc.runChatCompletionsAgentLoop, "chat_completions" } modelID := "" diff --git a/bridges/ai/session_greeting.go b/bridges/ai/session_greeting.go index 5594d6df..7e9b75cd 100644 --- a/bridges/ai/session_greeting.go +++ b/bridges/ai/session_greeting.go @@ -20,18 +20,30 @@ func maybePrependSessionGreeting( prompt []openai.ChatCompletionMessageParamUnion, log zerolog.Logger, ) []openai.ChatCompletionMessageParamUnion { + if greeting := sessionGreetingFragment(ctx, portal, meta, log); greeting != "" { + return append([]openai.ChatCompletionMessageParamUnion{openai.SystemMessage(greeting)}, prompt...) + } + return prompt +} + +func sessionGreetingFragment( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + log zerolog.Logger, +) string { if meta == nil { - return prompt + return "" } agentID := strings.TrimSpace(resolveAgentID(meta)) if agentID == "" { - return prompt + return "" } if meta.SessionBootstrapByAgent == nil { meta.SessionBootstrapByAgent = make(map[string]int64) } if meta.SessionBootstrapByAgent[agentID] != 0 { - return prompt + return "" } meta.SessionBootstrapByAgent[agentID] = time.Now().UnixMilli() if portal != nil { @@ -39,6 +51,5 @@ func maybePrependSessionGreeting( log.Warn().Err(err).Msg("Failed to persist session bootstrap state") } } - greeting := openai.SystemMessage(sessionGreetingPrompt) - return append([]openai.ChatCompletionMessageParamUnion{greeting}, prompt...) + return sessionGreetingPrompt } diff --git a/bridges/ai/streaming_chat_completions.go b/bridges/ai/streaming_chat_completions.go index d2881188..33fb8986 100644 --- a/bridges/ai/streaming_chat_completions.go +++ b/bridges/ai/streaming_chat_completions.go @@ -211,7 +211,3 @@ func (oc *AIClient) runChatCompletionsAgentLoop( } }) } - -// convertToResponsesInput converts Chat Completion messages to Responses API input items -// Supports native multimodal content: images (ResponseInputImageParam), files/PDFs (ResponseInputFileParam) -// Note: Audio is handled via Chat Completions API fallback (SDK v3.16.0 lacks Responses API audio union support) diff --git a/bridges/ai/streaming_continuation.go b/bridges/ai/streaming_continuation.go index ea2a4ec5..93540f84 100644 --- a/bridges/ai/streaming_continuation.go +++ b/bridges/ai/streaming_continuation.go @@ -4,7 +4,6 @@ import ( "context" "strings" - "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" ) @@ -45,7 +44,7 @@ func (oc *AIClient) buildContinuationParams( state.baseInput = append(state.baseInput, steerInput...) } } - return oc.buildResponsesAgentLoopParams(ctx, meta, input, true) + return oc.buildResponsesAgentLoopParams(ctx, meta, state.baseSystemPrompt, input, true) } func (oc *AIClient) buildSteeringInputItems(prompts []string, meta *PortalMetadata) responses.ResponseInputParam { @@ -58,8 +57,9 @@ func (oc *AIClient) buildSteeringInputItems(prompts []string, meta *PortalMetada if prompt == "" { continue } - messages := []openai.ChatCompletionMessageParamUnion{openai.UserMessage(prompt)} - input = append(input, oc.convertToResponsesInput(messages, meta)...) + input = append(input, PromptContextToResponsesInput(UserPromptContext( + PromptBlock{Type: PromptBlockText, Text: prompt}, + ))...) } return input } diff --git a/bridges/ai/streaming_input_conversion.go b/bridges/ai/streaming_input_conversion.go index b360733a..010056e9 100644 --- a/bridges/ai/streaming_input_conversion.go +++ b/bridges/ai/streaming_input_conversion.go @@ -2,15 +2,8 @@ package ai import ( "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/responses" - - bridgesdk "github.com/beeper/agentremote/sdk" ) -func (oc *AIClient) convertToResponsesInput(messages []openai.ChatCompletionMessageParamUnion, _ *PortalMetadata) responses.ResponseInputParam { - return bridgesdk.PromptContextToResponsesInput(bridgesdk.ChatMessagesToPromptContext(messages)) -} - // hasAudioContent checks if the prompt contains audio content func hasAudioContent(messages []openai.ChatCompletionMessageParamUnion) bool { for _, msg := range messages { diff --git a/bridges/ai/streaming_responses_api.go b/bridges/ai/streaming_responses_api.go index 53921829..b05585a6 100644 --- a/bridges/ai/streaming_responses_api.go +++ b/bridges/ai/streaming_responses_api.go @@ -38,14 +38,16 @@ func (a *responsesTurnAdapter) TrackRoomRunStreaming() bool { func (a *responsesTurnAdapter) startInitialRound(ctx context.Context) (*ssestream.Stream[responses.ResponseStreamEventUnion], error) { if !a.initialized { - input := a.oc.convertToResponsesInput(a.messages, a.meta) - a.params = a.oc.buildResponsesAgentLoopParams(ctx, a.meta, input, false) + promptContext := ChatMessagesToPromptContext(a.messages) + input := PromptContextToResponsesInput(promptContext) + a.params = a.oc.buildResponsesAgentLoopParams(ctx, a.meta, promptContext.SystemPrompt, input, false) if len(a.params.Tools) > 0 { zerolog.Ctx(ctx).Debug().Int("count", len(a.params.Tools)).Msg("Added streaming turn tools") } if a.oc.isOpenRouterProvider() { ctx = WithPDFEngine(ctx, a.oc.effectivePDFEngine(a.meta)) } + a.state.baseSystemPrompt = promptContext.SystemPrompt a.initialized = true } stream := a.oc.api.Responses.NewStreaming(ctx, a.params) @@ -190,7 +192,8 @@ func (a *responsesTurnAdapter) ContinueAgentLoop(messages []openai.ChatCompletio return } a.messages = append(a.messages, messages...) - a.state.baseInput = append(a.state.baseInput, a.oc.convertToResponsesInput(messages, a.meta)...) + promptContext := ChatMessagesToPromptContext(messages) + a.state.baseInput = append(a.state.baseInput, PromptContextToResponsesInput(promptContext)...) a.hasFollowUp = true } diff --git a/bridges/ai/streaming_responses_input_test.go b/bridges/ai/streaming_responses_input_test.go index de9bcbb2..6e950286 100644 --- a/bridges/ai/streaming_responses_input_test.go +++ b/bridges/ai/streaming_responses_input_test.go @@ -9,44 +9,32 @@ import ( ) func TestConvertToResponsesInput_RolesAndToolOutput(t *testing.T) { - oc := &AIClient{} - messages := []openai.ChatCompletionMessageParamUnion{ - openai.DeveloperMessage("dev instructions"), openai.UserMessage("hello"), openai.ToolMessage("tool output", "call_123"), } - input := oc.convertToResponsesInput(messages, nil) - if len(input) != 3 { - t.Fatalf("expected 3 input items, got %d", len(input)) - } - - if input[0].OfMessage == nil { - t.Fatalf("expected developer message input, got nil") - } - if input[0].OfMessage.Role != responses.EasyInputMessageRoleDeveloper { - t.Fatalf("expected developer role, got %s", input[0].OfMessage.Role) + input := PromptContextToResponsesInput(ChatMessagesToPromptContext(messages)) + if len(input) != 2 { + t.Fatalf("expected 2 input items, got %d", len(input)) } - if input[1].OfMessage == nil || input[1].OfMessage.Role != responses.EasyInputMessageRoleUser { - t.Fatalf("expected user message input for item 2") + if input[0].OfMessage == nil || input[0].OfMessage.Role != responses.EasyInputMessageRoleUser { + t.Fatalf("expected user message input first") } - if input[2].OfFunctionCallOutput == nil { - t.Fatalf("expected function_call_output input for item 3") + if input[1].OfFunctionCallOutput == nil { + t.Fatalf("expected function_call_output input second") } - if input[2].OfFunctionCallOutput.CallID != "call_123" { - t.Fatalf("expected call_id call_123, got %s", input[2].OfFunctionCallOutput.CallID) + if input[1].OfFunctionCallOutput.CallID != "call_123" { + t.Fatalf("expected call_id call_123, got %s", input[1].OfFunctionCallOutput.CallID) } - if input[2].OfFunctionCallOutput.Output.OfString.Value != "tool output" { - t.Fatalf("expected tool output to match, got %q", input[2].OfFunctionCallOutput.Output.OfString.Value) + if input[1].OfFunctionCallOutput.Output.OfString.Value != "tool output" { + t.Fatalf("expected tool output to match, got %q", input[1].OfFunctionCallOutput.Output.OfString.Value) } } func TestConvertToResponsesInput_AssistantToolCalls(t *testing.T) { - oc := &AIClient{} - messages := []openai.ChatCompletionMessageParamUnion{{ OfAssistant: &openai.ChatCompletionAssistantMessageParam{ Content: openai.ChatCompletionAssistantMessageParamContentUnion{ @@ -65,7 +53,7 @@ func TestConvertToResponsesInput_AssistantToolCalls(t *testing.T) { }, }} - input := oc.convertToResponsesInput(messages, nil) + input := PromptContextToResponsesInput(ChatMessagesToPromptContext(messages)) if len(input) != 2 { t.Fatalf("expected 2 input items, got %d", len(input)) } diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index 460e4f00..ac688887 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -38,6 +38,7 @@ type streamingState struct { reasoningTokens int64 totalTokens int64 + baseSystemPrompt string baseInput responses.ResponseInputParam accumulated strings.Builder reasoning strings.Builder diff --git a/bridges/ai/subagent_announce.go b/bridges/ai/subagent_announce.go index ffc92219..cd73675d 100644 --- a/bridges/ai/subagent_announce.go +++ b/bridges/ai/subagent_announce.go @@ -11,8 +11,6 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/id" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func formatDurationShort(valueMs int64) string { @@ -146,7 +144,7 @@ func (oc *AIClient) runSubagentCompletion( meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion, ) (bool, error) { - responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, PromptContext{PromptContext: bridgesdk.ChatMessagesToPromptContext(prompt)}) + responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, ChatMessagesToPromptContext(prompt)) return oc.responseWithRetry(ctx, nil, portal, meta, prompt, responseFn, logLabel) } diff --git a/bridges/ai/subagent_spawn.go b/bridges/ai/subagent_spawn.go index 1af0f146..ef66175d 100644 --- a/bridges/ai/subagent_spawn.go +++ b/bridges/ai/subagent_spawn.go @@ -351,7 +351,6 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P Timestamp: time.Now(), } setCanonicalTurnDataFromPromptMessages(userMessage.Metadata.(*MessageMetadata), promptTail(promptContext, 1)) - ensureCanonicalUserMessage(userMessage) if _, err := oc.UserLogin.Bridge.GetGhostByID(ctx, userMessage.SenderID); err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure user ghost before saving subagent task message") } diff --git a/bridges/ai/system_prompts.go b/bridges/ai/system_prompts.go index 17726a3c..5162772c 100644 --- a/bridges/ai/system_prompts.go +++ b/bridges/ai/system_prompts.go @@ -63,21 +63,62 @@ func (oc *AIClient) buildAdditionalSystemPrompts( return oc.additionalSystemMessages(ctx, portal, meta) } -func (oc *AIClient) buildSystemMessages( +func systemMessageText(messages []openai.ChatCompletionMessageParamUnion) string { + var parts []string + for _, msg := range messages { + if msg.OfSystem == nil { + continue + } + if text := strings.TrimSpace(msg.OfSystem.Content.OfString.Value); text != "" { + parts = append(parts, text) + continue + } + if len(msg.OfSystem.Content.OfArrayOfContentParts) == 0 { + continue + } + var lines []string + for _, part := range msg.OfSystem.Content.OfArrayOfContentParts { + if text := strings.TrimSpace(part.Text); text != "" { + lines = append(lines, text) + } + } + if len(lines) > 0 { + parts = append(parts, strings.Join(lines, "\n")) + } + } + return strings.TrimSpace(strings.Join(parts, "\n\n")) +} + +func (oc *AIClient) buildSystemPromptText( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - var msgs []openai.ChatCompletionMessageParamUnion - systemPrompt := oc.effectiveAgentPrompt(ctx, portal, meta) - if systemPrompt == "" { - systemPrompt = oc.effectivePrompt(meta) +) string { + base := oc.effectiveAgentPrompt(ctx, portal, meta) + if base == "" { + base = oc.effectivePrompt(meta) + } + fragments := []string{base, systemMessageText(oc.buildAdditionalSystemPrompts(ctx, portal, meta))} + var parts []string + for _, fragment := range fragments { + if text := strings.TrimSpace(fragment); text != "" { + parts = append(parts, text) + } } - if systemPrompt != "" { - msgs = append(msgs, openai.SystemMessage(systemPrompt)) + return strings.TrimSpace(strings.Join(parts, "\n\n")) +} + +func (oc *AIClient) buildConversationSystemPromptText( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + includeGreeting bool, +) string { + base := oc.buildSystemPromptText(ctx, portal, meta) + if !includeGreeting { + return base } - msgs = append(msgs, oc.buildAdditionalSystemPrompts(ctx, portal, meta)...) - return msgs + return joinPromptFragments(sessionGreetingFragment(ctx, portal, meta, oc.log), base) } func (oc *AIClient) buildAdditionalSystemPromptsCore( diff --git a/bridges/ai/text_files.go b/bridges/ai/text_files.go index d9447ecc..c463e72a 100644 --- a/bridges/ai/text_files.go +++ b/bridges/ai/text_files.go @@ -5,6 +5,9 @@ import ( "encoding/binary" "errors" "fmt" + "os" + "os/exec" + "path/filepath" "regexp" "strings" "unicode" @@ -22,6 +25,7 @@ var ( ) const maxTextFileBytes = 5 * 1024 * 1024 +const maxPDFFileBytes = 50 * 1024 * 1024 var textFileMimeTypesMap = map[string]event.CapabilitySupportLevel{ "text/plain": event.CapLevelFullySupported, @@ -192,6 +196,43 @@ func (oc *AIClient) downloadTextFile(ctx context.Context, mediaURL string, encry return trimmed, truncated, nil } +func (oc *AIClient) downloadPDFFile(ctx context.Context, mediaURL string, encryptedFile *event.EncryptedFileInfo, mimeType string) (string, bool, error) { + data, _, err := oc.downloadMediaBytes(ctx, mediaURL, encryptedFile, maxPDFFileBytes, mimeType) + if err != nil { + return "", false, err + } + + tempDir, err := os.MkdirTemp("", "ai-bridge-pdf-*") + if err != nil { + return "", false, fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + inputPath := filepath.Join(tempDir, "input.pdf") + if err := os.WriteFile(inputPath, data, 0o600); err != nil { + return "", false, fmt.Errorf("write temp pdf: %w", err) + } + + cmd := exec.CommandContext(ctx, "pdftotext", "-layout", "-enc", "UTF-8", inputPath, "-") + output, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + msg := strings.TrimSpace(string(exitErr.Stderr)) + if msg != "" { + return "", false, fmt.Errorf("pdftotext failed: %s", msg) + } + } + return "", false, fmt.Errorf("pdftotext failed: %w", err) + } + + text := strings.TrimSpace(string(output)) + if text == "" { + return "[No extractable text]", false, nil + } + trimmed, truncated := trimTextForModel(text) + return trimmed, truncated, nil +} + func buildTextFileMessage(caption string, hasUserCaption bool, filename string, mimeType string, content string, _ bool) string { if !hasUserCaption { caption = "" diff --git a/bridges/ai/tools_analyze_image.go b/bridges/ai/tools_analyze_image.go index 8b03aaaa..2c49f790 100644 --- a/bridges/ai/tools_analyze_image.go +++ b/bridges/ai/tools_analyze_image.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/beeper/agentremote/pkg/shared/media" - bridgesdk "github.com/beeper/agentremote/sdk" ) // executeAnalyzeImage analyzes an image with a custom prompt using vision capabilities. @@ -80,7 +79,7 @@ func executeAnalyzeImage(ctx context.Context, args map[string]any) (string, erro return "", errors.New("unsupported URL scheme, must be http://, https://, mxc://, or data URL") } - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( + ctxPrompt := UserPromptContext( PromptBlock{ Type: PromptBlockImage, ImageB64: imageB64, @@ -90,7 +89,7 @@ func executeAnalyzeImage(ctx context.Context, args map[string]any) (string, erro Type: PromptBlockText, Text: prompt, }, - )} + ) // Call the AI provider for vision analysis resp, err := btc.Client.provider.Generate(ctx, GenerateParams{ diff --git a/connector_builder.go b/connector_builder.go index f877d115..5e6920fb 100644 --- a/connector_builder.go +++ b/connector_builder.go @@ -14,8 +14,8 @@ type ConnectorSpec struct { AIRoomKind string Init func(*bridgev2.Bridge) - Start func(context.Context) error - Stop func(context.Context) + Start func(context.Context, *bridgev2.Bridge) error + Stop func(context.Context, *bridgev2.Bridge) Name func() bridgev2.BridgeName Config func() (example string, data any, upgrader configupgrade.Upgrader) @@ -62,14 +62,14 @@ func (c *ConnectorBase) Start(ctx context.Context) error { if c == nil || c.spec.Start == nil { return nil } - return c.spec.Start(ctx) + return c.spec.Start(ctx, c.br) } func (c *ConnectorBase) Stop(ctx context.Context) { if c == nil || c.spec.Stop == nil { return } - c.spec.Stop(ctx) + c.spec.Stop(ctx, c.br) } func (c *ConnectorBase) GetName() bridgev2.BridgeName { diff --git a/connector_builder_test.go b/connector_builder_test.go index 0c4d449d..135c155b 100644 --- a/connector_builder_test.go +++ b/connector_builder_test.go @@ -14,15 +14,24 @@ import ( func TestConnectorBaseHookOrder(t *testing.T) { var order []string + wantBridge := &bridgev2.Bridge{} conn := NewConnector(ConnectorSpec{ Init: func(*bridgev2.Bridge) { order = append(order, "init") }, - Start: func(context.Context) error { + Start: func(_ context.Context, got *bridgev2.Bridge) error { + if got != wantBridge { + t.Fatalf("expected start hook bridge %p, got %p", wantBridge, got) + } order = append(order, "start") return nil }, - Stop: func(context.Context) { order = append(order, "stop") }, + Stop: func(_ context.Context, got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected stop hook bridge %p, got %p", wantBridge, got) + } + order = append(order, "stop") + }, }) - conn.Init(nil) + conn.Init(wantBridge) if err := conn.Start(context.Background()); err != nil { t.Fatalf("start returned error: %v", err) } @@ -149,7 +158,7 @@ func TestConnectorStopCanDisconnectCachedClients(t *testing.T) { "b": &fakeClient{}, } conn := NewConnector(ConnectorSpec{ - Stop: func(context.Context) { + Stop: func(context.Context, *bridgev2.Bridge) { StopClients(&mu, &clients) }, }) diff --git a/sdk/connector.go b/sdk/connector.go index c6061593..2e59cf42 100644 --- a/sdk/connector.go +++ b/sdk/connector.go @@ -16,7 +16,6 @@ import ( // NewConnectorBase builds an SDK-backed connector base that can be embedded by custom bridges. func NewConnectorBase(cfg *Config) *agentremote.ConnectorBase { - var br *bridgev2.Bridge mu, clientsRef := cfg.ClientCacheMu, cfg.ClientCache if mu == nil { mu = &sync.Mutex{} @@ -66,23 +65,22 @@ func NewConnectorBase(cfg *Config) *agentremote.ConnectorBase { return agentremote.NewConnector(agentremote.ConnectorSpec{ ProtocolID: protocolID, Init: func(bridge *bridgev2.Bridge) { - br = bridge agentremote.EnsureClientMap(mu, clientsRef) if cfg.InitConnector != nil { cfg.InitConnector(bridge) } }, - Start: func(ctx context.Context) error { - registerCommands(br, cfg) + Start: func(ctx context.Context, bridge *bridgev2.Bridge) error { + registerCommands(bridge, cfg) if cfg.StartConnector != nil { - return cfg.StartConnector(ctx, br) + return cfg.StartConnector(ctx, bridge) } return nil }, - Stop: func(ctx context.Context) { + Stop: func(ctx context.Context, bridge *bridgev2.Bridge) { agentremote.StopClients(mu, clientsRef) if cfg.StopConnector != nil { - cfg.StopConnector(ctx, br) + cfg.StopConnector(ctx, bridge) } }, Name: func() bridgev2.BridgeName { diff --git a/sdk/connector_helpers.go b/sdk/connector_helpers.go index 139f5221..cf4f0798 100644 --- a/sdk/connector_helpers.go +++ b/sdk/connector_helpers.go @@ -32,6 +32,15 @@ func ApplyDefaultCommandPrefix(prefix *string, value string) { } } +// ResolveCommandPrefix returns the configured prefix when present, otherwise the +// bridge's declared default prefix without mutating configuration state. +func ResolveCommandPrefix(prefix string, fallback string) string { + if strings.TrimSpace(prefix) != "" { + return prefix + } + return fallback +} + // ApplyBoolDefault initializes a nil bool pointer to the provided value. func ApplyBoolDefault(target **bool, value bool) { if target == nil || *target != nil { diff --git a/sdk/connector_hooks_test.go b/sdk/connector_hooks_test.go index 4fc34874..cf6ed7ad 100644 --- a/sdk/connector_hooks_test.go +++ b/sdk/connector_hooks_test.go @@ -57,6 +57,7 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { createCalled := 0 updateCalled := 0 afterLoadCalled := 0 + wantBridge := &bridgev2.Bridge{} cfg := &Config{ Name: "hooked", @@ -68,12 +69,25 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { } return true, "" }, - InitConnector: func(*bridgev2.Bridge) { initCalled++ }, - StartConnector: func(context.Context, *bridgev2.Bridge) error { + InitConnector: func(got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected init bridge %p, got %p", wantBridge, got) + } + initCalled++ + }, + StartConnector: func(_ context.Context, got *bridgev2.Bridge) error { + if got != wantBridge { + t.Fatalf("expected start bridge %p, got %p", wantBridge, got) + } startCalled++ return nil }, - StopConnector: func(context.Context, *bridgev2.Bridge) { stopCalled++ }, + StopConnector: func(_ context.Context, got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected stop bridge %p, got %p", wantBridge, got) + } + stopCalled++ + }, MakeBrokenLogin: func(login *bridgev2.UserLogin, reason string) *agentremote.BrokenLoginClient { return agentremote.NewBrokenLoginClient(login, "custom:"+reason) }, @@ -89,7 +103,7 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { } conn := NewConnectorBase(cfg) - conn.Init(nil) + conn.Init(wantBridge) if err := conn.Start(context.Background()); err != nil { t.Fatalf("start returned error: %v", err) } diff --git a/sdk/prompt_context.go b/sdk/prompt_context.go deleted file mode 100644 index 109e6f73..00000000 --- a/sdk/prompt_context.go +++ /dev/null @@ -1,557 +0,0 @@ -package sdk - -import ( - "fmt" - "slices" - "strings" - - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" - "github.com/openai/openai-go/v3/responses" -) - -// PromptContext is the canonical provider-facing prompt representation. -type PromptContext struct { - SystemPrompt string - DeveloperPrompt string - Messages []PromptMessage -} - -func UserPromptContext(blocks ...PromptBlock) PromptContext { - return PromptContext{ - Messages: []PromptMessage{{ - Role: PromptRoleUser, - Blocks: slices.Clone(blocks), - }}, - } -} - -func PromptContextHasBlockType(ctx PromptContext, kinds ...PromptBlockType) bool { - if len(kinds) == 0 { - return false - } - allowed := make(map[PromptBlockType]struct{}, len(kinds)) - for _, kind := range kinds { - allowed[kind] = struct{}{} - } - for _, msg := range ctx.Messages { - for _, block := range msg.Blocks { - if _, ok := allowed[block.Type]; ok { - return true - } - } - } - return false -} - -// ChatMessagesToPromptContext converts chat-completions-shaped messages into the canonical prompt model. -func ChatMessagesToPromptContext(messages []openai.ChatCompletionMessageParamUnion) PromptContext { - var ctx PromptContext - AppendChatMessagesToPromptContext(&ctx, messages) - return ctx -} - -func AppendChatMessagesToPromptContext(ctx *PromptContext, messages []openai.ChatCompletionMessageParamUnion) { - if ctx == nil { - return - } - for _, msg := range messages { - appendChatMessageToPromptContext(ctx, msg) - } -} - -func appendChatMessageToPromptContext(ctx *PromptContext, msg openai.ChatCompletionMessageParamUnion) { - if ctx == nil { - return - } - switch { - case msg.OfSystem != nil: - AppendPromptText(&ctx.SystemPrompt, extractChatSystemText(msg.OfSystem.Content)) - case msg.OfDeveloper != nil: - AppendPromptText(&ctx.DeveloperPrompt, extractChatDeveloperText(msg.OfDeveloper.Content)) - case msg.OfUser != nil: - ctx.Messages = append(ctx.Messages, promptMessageFromChatUser(msg.OfUser)) - case msg.OfAssistant != nil: - ctx.Messages = append(ctx.Messages, promptMessageFromChatAssistant(msg.OfAssistant)) - case msg.OfTool != nil: - ctx.Messages = append(ctx.Messages, promptMessageFromChatTool(msg.OfTool)) - } -} - -func AppendPromptText(dst *string, text string) { - text = strings.TrimSpace(text) - if text == "" { - return - } - if *dst == "" { - *dst = text - return - } - *dst = strings.TrimSpace(*dst + "\n\n" + text) -} - -func promptMessageFromChatUser(msg *openai.ChatCompletionUserMessageParam) PromptMessage { - pm := PromptMessage{Role: PromptRoleUser} - if msg == nil { - return pm - } - if msg.Content.OfString.Value != "" { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: msg.Content.OfString.Value, - }) - } - for _, part := range msg.Content.OfArrayOfContentParts { - pm.Blocks = append(pm.Blocks, promptBlockFromChatUserPart(part)...) - } - return pm -} - -func promptBlockFromChatUserPart(part openai.ChatCompletionContentPartUnionParam) []PromptBlock { - switch { - case part.OfText != nil: - return []PromptBlock{{Type: PromptBlockText, Text: part.OfText.Text}} - case part.OfImageURL != nil: - return []PromptBlock{{ - Type: PromptBlockImage, - ImageURL: part.OfImageURL.ImageURL.URL, - MimeType: inferPromptMimeTypeFromDataURL(part.OfImageURL.ImageURL.URL), - }} - case part.OfFile != nil: - return []PromptBlock{{ - Type: PromptBlockFile, - FileB64: part.OfFile.File.FileData.Value, - Filename: part.OfFile.File.Filename.Value, - }} - case part.OfInputAudio != nil: - return []PromptBlock{{ - Type: PromptBlockAudio, - AudioB64: part.OfInputAudio.InputAudio.Data, - AudioFormat: part.OfInputAudio.InputAudio.Format, - }} - default: - return nil - } -} - -func promptMessageFromChatAssistant(msg *openai.ChatCompletionAssistantMessageParam) PromptMessage { - pm := PromptMessage{Role: PromptRoleAssistant} - if msg == nil { - return pm - } - if msg.Content.OfString.Value != "" { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: msg.Content.OfString.Value, - }) - } - for _, part := range msg.Content.OfArrayOfContentParts { - if part.OfText == nil { - continue - } - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: part.OfText.Text, - }) - } - for _, toolCall := range msg.ToolCalls { - if toolCall.OfFunction == nil { - continue - } - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockToolCall, - ToolCallID: toolCall.OfFunction.ID, - ToolName: toolCall.OfFunction.Function.Name, - ToolCallArguments: toolCall.OfFunction.Function.Arguments, - }) - } - return pm -} - -func promptMessageFromChatTool(msg *openai.ChatCompletionToolMessageParam) PromptMessage { - if msg == nil { - return PromptMessage{Role: PromptRoleToolResult} - } - pm := PromptMessage{ - Role: PromptRoleToolResult, - ToolCallID: msg.ToolCallID, - } - if msg.Content.OfString.Value != "" { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: msg.Content.OfString.Value, - }) - } - for _, part := range msg.Content.OfArrayOfContentParts { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: part.Text, - }) - } - return pm -} - -func extractChatSystemText(content openai.ChatCompletionSystemMessageParamContentUnion) string { - if content.OfString.Value != "" { - return content.OfString.Value - } - return joinChatText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { - return part.Text - }) -} - -func extractChatDeveloperText(content openai.ChatCompletionDeveloperMessageParamContentUnion) string { - if content.OfString.Value != "" { - return content.OfString.Value - } - return joinChatText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { - return part.Text - }) -} - -func joinChatText[T any](parts []T, extract func(T) string) string { - var values []string - for _, part := range parts { - if text := strings.TrimSpace(extract(part)); text != "" { - values = append(values, text) - } - } - return strings.Join(values, "\n") -} - -func inferPromptMimeTypeFromDataURL(value string) string { - value = strings.TrimSpace(value) - rest, ok := strings.CutPrefix(value, "data:") - if !ok { - return "" - } - value = rest - idx := strings.Index(value, ";") - if idx <= 0 { - return "" - } - return value[:idx] -} - -func BuildDataURL(mimeType, b64Data string) string { - return fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data) -} - -// resolveBlockImageURL returns the image URL for a prompt block, falling back -// to a base64 data URL when no explicit URL is provided. -func resolveBlockImageURL(block PromptBlock) string { - imageURL := strings.TrimSpace(block.ImageURL) - if imageURL == "" && block.ImageB64 != "" { - mimeType := block.MimeType - if mimeType == "" { - mimeType = "image/jpeg" - } - imageURL = BuildDataURL(mimeType, block.ImageB64) - } - return imageURL -} - -// PromptContextToResponsesInput converts the canonical prompt model into Responses input items. -func PromptContextToResponsesInput(ctx PromptContext) responses.ResponseInputParam { - var result responses.ResponseInputParam - - if strings.TrimSpace(ctx.DeveloperPrompt) != "" { - result = append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleDeveloper, - Content: responses.EasyInputMessageContentUnionParam{ - OfString: openai.String(ctx.DeveloperPrompt), - }, - }, - }) - } - - for _, msg := range ctx.Messages { - switch msg.Role { - case PromptRoleUser: - var contentParts responses.ResponseInputMessageContentListParam - hasMultimodal := false - textContent := "" - - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) == "" { - continue - } - if textContent != "" { - textContent += "\n" - } - textContent += block.Text - case PromptBlockImage: - imageURL := resolveBlockImageURL(block) - if imageURL == "" { - continue - } - hasMultimodal = true - contentParts = append(contentParts, responses.ResponseInputContentUnionParam{ - OfInputImage: &responses.ResponseInputImageParam{ - ImageURL: openai.String(imageURL), - Detail: responses.ResponseInputImageDetailAuto, - }, - }) - case PromptBlockFile: - fileData := strings.TrimSpace(block.FileB64) - fileURL := strings.TrimSpace(block.FileURL) - if fileData == "" && fileURL == "" { - continue - } - hasMultimodal = true - fileParam := &responses.ResponseInputFileParam{} - if fileData != "" { - fileParam.FileData = openai.String(fileData) - } - if fileURL != "" { - fileParam.FileURL = openai.String(fileURL) - } - if strings.TrimSpace(block.Filename) != "" { - fileParam.Filename = openai.String(block.Filename) - } - contentParts = append(contentParts, responses.ResponseInputContentUnionParam{ - OfInputFile: fileParam, - }) - case PromptBlockAudio, PromptBlockVideo: - // Unsupported in Responses API; caller should fall back to Chat Completions. - } - } - - if textContent != "" { - textPart := responses.ResponseInputContentUnionParam{ - OfInputText: &responses.ResponseInputTextParam{Text: textContent}, - } - contentParts = append([]responses.ResponseInputContentUnionParam{textPart}, contentParts...) - } - - if hasMultimodal && len(contentParts) > 0 { - result = append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleUser, - Content: responses.EasyInputMessageContentUnionParam{ - OfInputItemContentList: contentParts, - }, - }, - }) - } else if textContent != "" { - result = append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleUser, - Content: responses.EasyInputMessageContentUnionParam{ - OfString: openai.String(textContent), - }, - }, - }) - } - case PromptRoleAssistant: - textParts := make([]string, 0, len(msg.Blocks)) - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) != "" { - textParts = append(textParts, block.Text) - } - case PromptBlockToolCall: - callID := strings.TrimSpace(block.ToolCallID) - name := strings.TrimSpace(block.ToolName) - args := strings.TrimSpace(block.ToolCallArguments) - if callID == "" || name == "" { - continue - } - if args == "" { - args = "{}" - } - result = appendAssistantTextItem(result, textParts) - textParts = textParts[:0] - result = append(result, responses.ResponseInputItemParamOfFunctionCall(args, callID, name)) - } - } - result = appendAssistantTextItem(result, textParts) - case PromptRoleToolResult: - callID := strings.TrimSpace(msg.ToolCallID) - output := strings.TrimSpace(msg.Text()) - if callID == "" || output == "" { - continue - } - result = append(result, responses.ResponseInputItemUnionParam{ - OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ - CallID: callID, - Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ - OfString: openai.String(output), - }, - }, - }) - } - } - - return result -} - -func appendAssistantTextItem(result responses.ResponseInputParam, textParts []string) responses.ResponseInputParam { - text := strings.TrimSpace(strings.Join(textParts, "")) - if text == "" { - return result - } - return append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleAssistant, - Content: responses.EasyInputMessageContentUnionParam{ - OfString: openai.String(text), - }, - }, - }) -} - -// PromptContextToChatCompletionMessages converts the canonical prompt model into Chat Completions messages. -func PromptContextToChatCompletionMessages(ctx PromptContext, supportsVideoURL bool) []openai.ChatCompletionMessageParamUnion { - result := make([]openai.ChatCompletionMessageParamUnion, 0, len(ctx.Messages)+2) - if strings.TrimSpace(ctx.SystemPrompt) != "" { - result = append(result, openai.SystemMessage(ctx.SystemPrompt)) - } - if strings.TrimSpace(ctx.DeveloperPrompt) != "" { - result = append(result, openai.ChatCompletionMessageParamUnion{ - OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ - OfString: openai.String(ctx.DeveloperPrompt), - }, - }, - }) - } - - for _, msg := range ctx.Messages { - switch msg.Role { - case PromptRoleUser: - if promptMessageHasMultimodal(msg) { - result = append(result, openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfArrayOfContentParts: promptBlocksToChatCompletionContentParts(msg.Blocks, supportsVideoURL), - }, - }, - }) - } else { - result = append(result, openai.UserMessage(msg.Text())) - } - case PromptRoleAssistant: - assistant := &openai.ChatCompletionAssistantMessageParam{ - Content: openai.ChatCompletionAssistantMessageParamContentUnion{ - OfString: openai.String(msg.Text()), - }, - } - for _, block := range msg.Blocks { - if block.Type != PromptBlockToolCall { - continue - } - args := strings.TrimSpace(block.ToolCallArguments) - if args == "" { - args = "{}" - } - assistant.ToolCalls = append(assistant.ToolCalls, openai.ChatCompletionMessageToolCallUnionParam{ - OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ - ID: block.ToolCallID, - Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ - Name: block.ToolName, - Arguments: args, - }, - Type: "function", - }, - }) - } - result = append(result, openai.ChatCompletionMessageParamUnion{OfAssistant: assistant}) - case PromptRoleToolResult: - result = append(result, openai.ToolMessage(msg.Text(), msg.ToolCallID)) - } - } - - return result -} - -func promptMessageHasMultimodal(msg PromptMessage) bool { - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockImage, PromptBlockFile, PromptBlockAudio, PromptBlockVideo: - return true - } - } - return false -} - -func promptBlocksToChatCompletionContentParts(blocks []PromptBlock, supportsVideoURL bool) []openai.ChatCompletionContentPartUnionParam { - result := make([]openai.ChatCompletionContentPartUnionParam, 0, len(blocks)) - for _, block := range blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) == "" { - continue - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: block.Text}, - }) - case PromptBlockImage: - imageURL := resolveBlockImageURL(block) - if imageURL == "" { - continue - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: imageURL}, - }, - }) - case PromptBlockFile: - file := openai.ChatCompletionContentPartFileFileParam{} - if strings.TrimSpace(block.FileB64) != "" { - file.FileData = param.NewOpt(block.FileB64) - } - if strings.TrimSpace(block.Filename) != "" { - file.Filename = param.NewOpt(block.Filename) - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfFile: &openai.ChatCompletionContentPartFileParam{File: file}, - }) - case PromptBlockAudio: - if strings.TrimSpace(block.AudioB64) == "" { - continue - } - format := strings.TrimSpace(block.AudioFormat) - if format == "" { - format = "mp3" - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ - InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ - Data: block.AudioB64, - Format: format, - }, - }, - }) - case PromptBlockVideo: - videoURL := strings.TrimSpace(block.VideoURL) - if videoURL == "" && block.VideoB64 != "" { - mimeType := strings.TrimSpace(block.MimeType) - if mimeType == "" { - mimeType = "video/mp4" - } - videoURL = BuildDataURL(mimeType, block.VideoB64) - } - if videoURL == "" { - continue - } - if supportsVideoURL { - result = append(result, param.Override[openai.ChatCompletionContentPartUnionParam](map[string]any{ - "type": "video_url", - "video_url": map[string]any{ - "url": videoURL, - }, - })) - } - } - } - return result -} - -func HasUnsupportedResponsesPromptContext(ctx PromptContext) bool { - return PromptContextHasBlockType(ctx, PromptBlockAudio, PromptBlockVideo) -} diff --git a/sdk/prompt_projection.go b/sdk/prompt_projection.go deleted file mode 100644 index 7096c887..00000000 --- a/sdk/prompt_projection.go +++ /dev/null @@ -1,284 +0,0 @@ -package sdk - -import ( - "encoding/json" - "fmt" - "strings" -) - -type PromptRole string - -const ( - PromptRoleUser PromptRole = "user" - PromptRoleAssistant PromptRole = "assistant" - PromptRoleToolResult PromptRole = "tool_result" -) - -type PromptBlockType string - -const ( - PromptBlockText PromptBlockType = "text" - PromptBlockImage PromptBlockType = "image" - PromptBlockFile PromptBlockType = "file" - PromptBlockThinking PromptBlockType = "thinking" - PromptBlockToolCall PromptBlockType = "tool_call" - PromptBlockAudio PromptBlockType = "audio" - PromptBlockVideo PromptBlockType = "video" -) - -type PromptBlock struct { - Type PromptBlockType - - Text string - - ImageURL string - ImageB64 string - MimeType string - - FileURL string - FileB64 string - Filename string - - ToolCallID string - ToolName string - ToolCallArguments string - - AudioB64 string - AudioFormat string - - VideoURL string - VideoB64 string -} - -type PromptMessage struct { - Role PromptRole - Blocks []PromptBlock - ToolCallID string - ToolName string - IsError bool -} - -func (m PromptMessage) Text() string { - var texts []string - for _, block := range m.Blocks { - switch block.Type { - case PromptBlockText, PromptBlockThinking: - if strings.TrimSpace(block.Text) != "" { - texts = append(texts, block.Text) - } - } - } - return strings.Join(texts, "\n") -} - -func PromptMessagesFromTurnData(td TurnData) []PromptMessage { - if td.Role == "" { - return nil - } - switch td.Role { - case "user": - msg := PromptMessage{Role: PromptRoleUser} - for _, part := range td.Parts { - switch normalizeTurnPartType(part.Type) { - case "text": - if strings.TrimSpace(part.Text) != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) - } - case "image": - if strings.TrimSpace(part.URL) != "" || promptExtraString(part.Extra, "imageB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockImage, - ImageURL: part.URL, - ImageB64: promptExtraString(part.Extra, "imageB64"), - MimeType: part.MediaType, - }) - } - case "file": - if strings.TrimSpace(part.URL) != "" || strings.TrimSpace(part.Filename) != "" || promptExtraString(part.Extra, "fileB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockFile, - FileURL: part.URL, - FileB64: promptExtraString(part.Extra, "fileB64"), - Filename: part.Filename, - MimeType: part.MediaType, - }) - } - case "audio": - if promptExtraString(part.Extra, "audioB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockAudio, - AudioB64: promptExtraString(part.Extra, "audioB64"), - AudioFormat: promptExtraString(part.Extra, "audioFormat"), - MimeType: part.MediaType, - }) - } - case "video": - if strings.TrimSpace(part.URL) != "" || promptExtraString(part.Extra, "videoB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockVideo, - VideoURL: part.URL, - VideoB64: promptExtraString(part.Extra, "videoB64"), - MimeType: part.MediaType, - }) - } - } - } - if len(msg.Blocks) == 0 { - return nil - } - return []PromptMessage{msg} - case "assistant": - assistant := PromptMessage{Role: PromptRoleAssistant} - var results []PromptMessage - for _, part := range td.Parts { - switch normalizeTurnPartType(part.Type) { - case "text": - if strings.TrimSpace(part.Text) != "" { - assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) - } - case "reasoning": - text := strings.TrimSpace(part.Reasoning) - if text == "" { - text = strings.TrimSpace(part.Text) - } - if text != "" { - assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockThinking, Text: text}) - } - case "tool": - if strings.TrimSpace(part.ToolCallID) != "" && strings.TrimSpace(part.ToolName) != "" { - assistant.Blocks = append(assistant.Blocks, PromptBlock{ - Type: PromptBlockToolCall, - ToolCallID: part.ToolCallID, - ToolName: part.ToolName, - ToolCallArguments: CanonicalToolArguments(part.Input), - }) - } - outputText := strings.TrimSpace(FormatCanonicalValue(part.Output)) - if outputText == "" { - outputText = strings.TrimSpace(part.ErrorText) - } - if outputText == "" && part.State == "output-denied" { - outputText = "Denied by user" - } - if strings.TrimSpace(part.ToolCallID) != "" && outputText != "" { - results = append(results, PromptMessage{ - Role: PromptRoleToolResult, - ToolCallID: part.ToolCallID, - ToolName: part.ToolName, - IsError: strings.TrimSpace(part.ErrorText) != "", - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: outputText, - }}, - }) - } - } - } - if len(assistant.Blocks) == 0 && len(results) == 0 { - return nil - } - out := make([]PromptMessage, 0, 1+len(results)) - if len(assistant.Blocks) > 0 { - out = append(out, assistant) - } - out = append(out, results...) - return out - default: - return nil - } -} - -func TurnDataFromUserPromptMessages(messages []PromptMessage) (TurnData, bool) { - if len(messages) == 0 { - return TurnData{}, false - } - msg := messages[0] - if msg.Role != PromptRoleUser { - return TurnData{}, false - } - td := TurnData{Role: "user"} - td.Parts = make([]TurnPart, 0, len(msg.Blocks)) - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) != "" { - td.Parts = append(td.Parts, TurnPart{Type: "text", Text: block.Text}) - } - case PromptBlockImage: - if strings.TrimSpace(block.ImageURL) != "" || strings.TrimSpace(block.ImageB64) != "" { - part := TurnPart{Type: "image", URL: block.ImageURL, MediaType: block.MimeType} - if strings.TrimSpace(block.ImageB64) != "" { - part.Extra = map[string]any{"imageB64": block.ImageB64} - } - td.Parts = append(td.Parts, part) - } - case PromptBlockFile: - if strings.TrimSpace(block.FileURL) != "" || strings.TrimSpace(block.FileB64) != "" || strings.TrimSpace(block.Filename) != "" { - part := TurnPart{ - Type: "file", - URL: block.FileURL, - Filename: block.Filename, - MediaType: block.MimeType, - } - if strings.TrimSpace(block.FileB64) != "" { - part.Extra = map[string]any{"fileB64": block.FileB64} - } - td.Parts = append(td.Parts, part) - } - case PromptBlockAudio: - if strings.TrimSpace(block.AudioB64) != "" { - td.Parts = append(td.Parts, TurnPart{ - Type: "audio", - MediaType: block.MimeType, - Extra: map[string]any{ - "audioB64": block.AudioB64, - "audioFormat": block.AudioFormat, - }, - }) - } - case PromptBlockVideo: - if strings.TrimSpace(block.VideoURL) != "" || strings.TrimSpace(block.VideoB64) != "" { - part := TurnPart{ - Type: "video", - URL: block.VideoURL, - MediaType: block.MimeType, - } - if strings.TrimSpace(block.VideoB64) != "" { - part.Extra = map[string]any{"videoB64": block.VideoB64} - } - td.Parts = append(td.Parts, part) - } - } - } - return td, len(td.Parts) > 0 -} - -func CanonicalToolArguments(raw any) string { - if value := strings.TrimSpace(FormatCanonicalValue(raw)); value != "" { - return value - } - return "{}" -} - -func FormatCanonicalValue(raw any) string { - switch typed := raw.(type) { - case nil: - return "" - case string: - return typed - default: - data, err := json.Marshal(typed) - if err != nil { - return fmt.Sprint(typed) - } - return string(data) - } -} - -func promptExtraString(extra map[string]any, key string) string { - if len(extra) == 0 { - return "" - } - value, _ := extra[key].(string) - return strings.TrimSpace(value) -} diff --git a/sdk/turn_data_test.go b/sdk/turn_data_test.go index 28bd4c3d..98986651 100644 --- a/sdk/turn_data_test.go +++ b/sdk/turn_data_test.go @@ -109,66 +109,3 @@ func TestBuildTurnDataFromUIMessageMergesRuntimeState(t *testing.T) { t.Fatalf("expected source-url part, got %#v", td.Parts) } } - -func TestPromptMessagesFromTurnData(t *testing.T) { - td := TurnData{ - Role: "assistant", - Parts: []TurnPart{ - {Type: "text", Text: "hello"}, - {Type: "reasoning", Reasoning: "thinking"}, - {Type: "tool", ToolCallID: "tool-1", ToolName: "search", Input: map[string]any{"q": "matrix"}, Output: map[string]any{"done": true}}, - }, - } - - messages := PromptMessagesFromTurnData(td) - if len(messages) != 2 { - t.Fatalf("expected assistant + tool result, got %#v", messages) - } - if messages[0].Role != PromptRoleAssistant { - t.Fatalf("unexpected assistant role %#v", messages[0]) - } - if messages[1].Role != PromptRoleToolResult || messages[1].ToolCallID != "tool-1" { - t.Fatalf("unexpected tool result %#v", messages[1]) - } -} - -func TestTurnDataFromUserPromptMessagesPreservesInlineMedia(t *testing.T) { - messages := []PromptMessage{{ - Role: PromptRoleUser, - Blocks: []PromptBlock{ - {Type: PromptBlockText, Text: "describe these attachments"}, - {Type: PromptBlockImage, ImageB64: "aW1hZ2U=", MimeType: "image/png"}, - {Type: PromptBlockFile, FileB64: "data:application/pdf;base64,cGRm", Filename: "doc.pdf", MimeType: "application/pdf"}, - {Type: PromptBlockAudio, AudioB64: "YXVkaW8=", AudioFormat: "mp3", MimeType: "audio/mpeg"}, - {Type: PromptBlockVideo, VideoB64: "dmlkZW8=", MimeType: "video/mp4"}, - }, - }} - - td, ok := TurnDataFromUserPromptMessages(messages) - if !ok { - t.Fatal("expected user prompt messages to produce turn data") - } - if len(td.Parts) != 5 { - t.Fatalf("expected 5 parts, got %#v", td.Parts) - } - - roundTrip := PromptMessagesFromTurnData(td) - if len(roundTrip) != 1 || len(roundTrip[0].Blocks) != 5 { - t.Fatalf("expected one user message with 5 blocks, got %#v", roundTrip) - } - if got := roundTrip[0].Blocks[1].ImageB64; got != "aW1hZ2U=" { - t.Fatalf("expected inline image to round-trip, got %#v", roundTrip[0].Blocks[1]) - } - if got := roundTrip[0].Blocks[2].FileB64; got != "data:application/pdf;base64,cGRm" { - t.Fatalf("expected inline file to round-trip, got %#v", roundTrip[0].Blocks[2]) - } - if got := roundTrip[0].Blocks[3].AudioB64; got != "YXVkaW8=" { - t.Fatalf("expected inline audio to round-trip, got %#v", roundTrip[0].Blocks[3]) - } - if got := roundTrip[0].Blocks[3].AudioFormat; got != "mp3" { - t.Fatalf("expected audio format to round-trip, got %#v", roundTrip[0].Blocks[3]) - } - if got := roundTrip[0].Blocks[4].VideoB64; got != "dmlkZW8=" { - t.Fatalf("expected inline video to round-trip, got %#v", roundTrip[0].Blocks[4]) - } -} diff --git a/sdk/turn_snapshot.go b/sdk/turn_snapshot.go index 005d5d7b..821c02ce 100644 --- a/sdk/turn_snapshot.go +++ b/sdk/turn_snapshot.go @@ -10,7 +10,6 @@ import ( type TurnSnapshot struct { TurnData TurnData UIMessage map[string]any - PromptMessages []PromptMessage Body string ThinkingContent string ToolCalls []agentremote.ToolCallMetadata @@ -25,7 +24,6 @@ func SnapshotFromTurnData(td TurnData, toolType string) TurnSnapshot { return TurnSnapshot{ TurnData: td.Clone(), UIMessage: UIMessageFromTurnData(td), - PromptMessages: PromptMessagesFromTurnData(td), Body: TurnText(td), ThinkingContent: TurnReasoningText(td), ToolCalls: TurnToolCalls(td, toolType), From ca6763cf18ff7681ee8623611e6922fbeb3ddc72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 21:42:09 +0200 Subject: [PATCH 02/23] Refactor agent defaults and prompt builders Centralize agent defaults and refactor prompt construction. Move provider-specific settings into Agents.Defaults (model selections, PDF engine, compaction) and initialize them in runtime defaults; remove the old ProvidersConfig/ProviderConfig fields. Rework tools config shape (introduce tools.web and tools.links) and migrate link-preview/config paths accordingly. Introduce ModelSelectionConfig and use it for default model selection. Consolidate multiple prompt-building helpers into unified turn-based builders (buildCurrentTurnWithLinks, buildMediaTurnContext, buildHeartbeatTurnContext, buildPromptContextForTurn) and update callers across handlers and internal dispatch. Provide default PDF engine fallback ("mistral-ocr") via helper methods. Apply related test updates and remove the example integrations config. These changes simplify configuration, centralize defaults, and reduce duplication in prompt assembly. --- bridges/ai/client.go | 169 +++----- bridges/ai/connector.go | 12 +- bridges/ai/handlematrix.go | 8 +- bridges/ai/heartbeat_execute.go | 17 +- bridges/ai/integrations_config.go | 212 +++++----- bridges/ai/integrations_example-config.yaml | 361 +++++------------- bridges/ai/internal_dispatch.go | 2 +- bridges/ai/login.go | 16 +- bridges/ai/media_understanding_runner.go | 5 +- .../media_understanding_runner_openai_test.go | 12 +- bridges/ai/prompt_builder.go | 111 ++++++ bridges/ai/response_retry_test.go | 4 +- bridges/ai/runtime_compaction_adapter.go | 6 +- bridges/ai/runtime_defaults_test.go | 34 +- bridges/ai/status_text.go | 4 +- bridges/ai/streaming_request_tools_test.go | 4 +- bridges/ai/streaming_tool_selection_test.go | 8 +- bridges/ai/subagent_spawn.go | 2 +- bridges/ai/token_resolver.go | 19 +- .../ai/tool_availability_configured_test.go | 10 +- bridges/ai/tool_configured.go | 8 +- bridges/ai/tools_tts_test.go | 6 +- 22 files changed, 459 insertions(+), 571 deletions(-) diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 2001fdc8..54386239 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -485,16 +485,14 @@ func initProviderForLogin(key string, meta *UserLoginMetadata, connector *OpenAI } switch meta.Provider { case ProviderOpenRouter: - pdfEngine := connector.Config.Providers.OpenRouter.DefaultPDFEngine - return initOpenRouterProvider(key, connector.resolveOpenRouterBaseURL(), "", pdfEngine, ProviderOpenRouter, log) + return initOpenRouterProvider(key, connector.resolveOpenRouterBaseURL(), "", connector.defaultPDFEngineForInit(), ProviderOpenRouter, log) case ProviderMagicProxy: baseURL := normalizeProxyBaseURL(meta.BaseURL) if baseURL == "" { return nil, errors.New("magic proxy base_url is required") } - pdfEngine := connector.Config.Providers.OpenRouter.DefaultPDFEngine - return initOpenRouterProvider(key, joinProxyPath(baseURL, "/openrouter/v1"), "", pdfEngine, ProviderMagicProxy, log) + return initOpenRouterProvider(key, joinProxyPath(baseURL, "/openrouter/v1"), "", connector.defaultPDFEngineForInit(), ProviderMagicProxy, log) case ProviderOpenAI: openaiURL := connector.resolveOpenAIBaseURL() @@ -509,6 +507,15 @@ func initProviderForLogin(key string, meta *UserLoginMetadata, connector *OpenAI } } +func (oc *OpenAIConnector) defaultPDFEngineForInit() string { + if oc != nil && oc.Config.Agents != nil && oc.Config.Agents.Defaults != nil { + if engine := strings.TrimSpace(oc.Config.Agents.Defaults.PDFEngine); engine != "" { + return engine + } + } + return "mistral-ocr" +} + // initOpenRouterProvider creates an OpenRouter-compatible provider with PDF support. func initOpenRouterProvider(key, url, userID, pdfEngine, providerName string, log zerolog.Logger) (*OpenAIProvider, error) { log.Info(). @@ -792,9 +799,9 @@ func (oc *AIClient) processPendingQueue(ctx context.Context, roomID id.RoomID) { } switch item.pending.Type { case pendingTypeText: - promptContext, err = oc.buildContextWithLinkContext(promptCtx, item.pending.Portal, metaSnapshot, prompt, item.rawEventContent, eventID) + promptContext, err = oc.buildCurrentTurnWithLinks(promptCtx, item.pending.Portal, metaSnapshot, prompt, item.rawEventContent, eventID) case pendingTypeImage, pendingTypePDF, pendingTypeAudio, pendingTypeVideo: - promptContext, err = oc.buildContextWithMedia(promptCtx, item.pending.Portal, metaSnapshot, item.pending.MessageBody, item.pending.MediaURL, item.pending.MimeType, item.pending.EncryptedFile, item.pending.Type, eventID) + promptContext, err = oc.buildMediaTurnContext(promptCtx, item.pending.Portal, metaSnapshot, item.pending.MessageBody, item.pending.MediaURL, item.pending.MimeType, item.pending.EncryptedFile, item.pending.Type, eventID) case pendingTypeRegenerate: promptContext, err = oc.buildContextForRegenerate(promptCtx, item.pending.Portal, metaSnapshot, item.pending.MessageBody, item.pending.SourceEventID) case pendingTypeEditRegenerate: @@ -1147,17 +1154,15 @@ func (oc *AIClient) defaultModelForProvider() string { if loginMeta == nil { return DefaultModelOpenRouter } - providers := oc.connector.Config.Providers - switch loginMeta.Provider { case ProviderOpenAI: - if providers.OpenAI.DefaultModel != "" { - return providers.OpenAI.DefaultModel + if configured := strings.TrimSpace(oc.defaultModelSelection(ProviderOpenAI).Primary); configured != "" { + return configured } return DefaultModelOpenAI case ProviderOpenRouter, ProviderMagicProxy: - if providers.OpenRouter.DefaultModel != "" { - return providers.OpenRouter.DefaultModel + if configured := strings.TrimSpace(oc.defaultModelSelection(ProviderOpenRouter).Primary); configured != "" { + return configured } return DefaultModelOpenRouter default: @@ -1165,6 +1170,23 @@ func (oc *AIClient) defaultModelForProvider() string { } } +func (oc *AIClient) defaultModelSelection(provider string) ModelSelectionConfig { + if oc == nil || oc.connector == nil || oc.connector.Config.Agents == nil || oc.connector.Config.Agents.Defaults == nil || oc.connector.Config.Agents.Defaults.Model == nil { + return ModelSelectionConfig{} + } + selection := *oc.connector.Config.Agents.Defaults.Model + if strings.TrimSpace(selection.Primary) != "" { + return selection + } + switch strings.ToLower(strings.TrimSpace(provider)) { + case ProviderOpenAI: + selection.Primary = DefaultModelOpenAI + case ProviderOpenRouter, ProviderMagicProxy: + selection.Primary = DefaultModelOpenRouter + } + return selection +} + // effectivePrompt returns the base system prompt to use for non-agent rooms. func (oc *AIClient) effectivePrompt(meta *PortalMetadata) string { base := oc.connector.Config.DefaultSystemPrompt @@ -1215,8 +1237,8 @@ func (oc *AIClient) profilePromptSupplement() string { func getLinkPreviewConfig(connectorConfig *Config) LinkPreviewConfig { config := DefaultLinkPreviewConfig() - if connectorConfig.LinkPreviews != nil { - cfg := connectorConfig.LinkPreviews + if connectorConfig.Tools.Links != nil { + cfg := connectorConfig.Tools.Links // Apply explicit settings only if they differ from zero values if !cfg.Enabled { config.Enabled = cfg.Enabled @@ -1489,24 +1511,24 @@ func (oc *AIClient) isGroupChat(ctx context.Context, portal *bridgev2.Portal) bo return len(members) > 2 } +func (oc *AIClient) defaultPDFEngine() string { + if oc != nil && oc.connector != nil && oc.connector.Config.Agents != nil && + oc.connector.Config.Agents.Defaults != nil { + if engine := strings.TrimSpace(oc.connector.Config.Agents.Defaults.PDFEngine); engine != "" { + return engine + } + } + return "mistral-ocr" +} + // effectivePDFEngine returns the PDF engine to use for the given portal. -// Priority: room-level PDFConfig > provider-level config > default "mistral-ocr" +// Priority: room-level PDFConfig > agent defaults > default "mistral-ocr" func (oc *AIClient) effectivePDFEngine(meta *PortalMetadata) string { // Room-level override if meta != nil && meta.PDFConfig != nil && meta.PDFConfig.Engine != "" { return meta.PDFConfig.Engine } - - // Provider-level config - loginMeta := loginMetadata(oc.UserLogin) - switch loginMeta.Provider { - case ProviderOpenRouter, ProviderMagicProxy: - if engine := oc.connector.Config.Providers.OpenRouter.DefaultPDFEngine; engine != "" { - return engine - } - } - - return "mistral-ocr" // Default + return oc.defaultPDFEngine() } // validateModel checks if a model is available for this user @@ -1707,18 +1729,6 @@ func (oc *AIClient) buildBaseContext( return promptContext, nil } -func (oc *AIClient) buildBasePrompt( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, -) ([]openai.ChatCompletionMessageParamUnion, error) { - promptContext, err := oc.buildBaseContext(ctx, portal, meta) - if err != nil { - return nil, err - } - return oc.promptContextToDispatchMessages(ctx, portal, meta, promptContext), nil -} - func (oc *AIClient) applyAbortHint(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, body string) string { if meta == nil || !meta.AbortedLastRun { return body @@ -1772,30 +1782,6 @@ func (oc *AIClient) prepareInboundPromptContext( UntrustedPrefix: untrustedPrefix, }, nil } -func (oc *AIClient) buildContextWithLinkContext( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - latest string, - rawEventContent map[string]any, - eventID id.EventID, -) (PromptContext, error) { - promptContext, text, err := oc.buildCurrentTurnText(ctx, portal, meta, latest, eventID, currentTurnTextOptions{ - rawEventContent: rawEventContent, - includeLinkScope: true, - }) - if err != nil { - return PromptContext{}, err - } - promptContext.Messages = append(promptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: text, - }}, - }) - return promptContext, nil -} // buildLinkContext extracts URLs from the message, fetches previews, and returns formatted context. func (oc *AIClient) buildLinkContext(ctx context.Context, message string, rawEventContent map[string]any) string { @@ -1857,8 +1843,8 @@ func (oc *AIClient) buildLinkContext(ctx context.Context, message string, rawEve return FormatPreviewsForContext(allPreviews, config.MaxContentChars) } -// buildPromptWithMedia builds a prompt with media content. -func (oc *AIClient) buildContextWithMedia( +// buildMediaTurnContext builds a prompt turn with media content. +func (oc *AIClient) buildMediaTurnContext( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, @@ -1869,54 +1855,15 @@ func (oc *AIClient) buildContextWithMedia( mediaType pendingMessageType, eventID id.EventID, ) (PromptContext, error) { - appendBlocks := make([]string, 0, 1) - blocks := make([]PromptBlock, 0, 2) - - switch mediaType { - case pendingTypeImage: - b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, mediaURL, encryptedFile, 20, mimeType) // 20MB limit for images - if err != nil { - return PromptContext{}, fmt.Errorf("failed to download image: %w", err) - } - blocks = append(blocks, PromptBlock{ - Type: PromptBlockImage, - ImageB64: b64Data, - MimeType: actualMimeType, - }) - - case pendingTypePDF: - content, truncated, err := oc.downloadPDFFile(ctx, mediaURL, encryptedFile, mimeType) - if err != nil { - return PromptContext{}, fmt.Errorf("failed to download PDF: %w", err) - } - filename := resolveMediaFileName("document.pdf", "pdf", mediaURL) - appendBlocks = append(appendBlocks, buildTextFileMessage("", false, filename, "application/pdf", content, truncated)) - - case pendingTypeAudio: - return PromptContext{}, fmt.Errorf("audio attachments must be preprocessed into text before prompt assembly") - - case pendingTypeVideo: - return PromptContext{}, fmt.Errorf("video attachments must be preprocessed into text before prompt assembly") - - default: - return PromptContext{}, fmt.Errorf("unsupported media type: %s", mediaType) - } - - promptContext, text, err := oc.buildCurrentTurnText(ctx, portal, meta, caption, eventID, currentTurnTextOptions{ + return oc.buildPromptContextForTurn(ctx, portal, meta, caption, eventID, currentTurnPromptOptions{ includeLinkScope: true, - append: appendBlocks, - }) - if err != nil { - return PromptContext{}, err - } - if strings.TrimSpace(text) != "" { - blocks = append([]PromptBlock{{Type: PromptBlockText, Text: text}}, blocks...) - } - promptContext.Messages = append(promptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: blocks, + attachment: &turnAttachmentOptions{ + mediaURL: mediaURL, + mimeType: mimeType, + encryptedFile: encryptedFile, + mediaType: mediaType, + }, }) - return promptContext, nil } // buildPromptUpToMessage builds a prompt including messages up to and including the specified message @@ -2133,7 +2080,7 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { } // Build prompt with combined body - promptContext, err := oc.buildContextWithLinkContext(statusCtx, last.Portal, last.Meta, combinedBody, rawEventContent, last.Event.ID) + promptContext, err := oc.buildCurrentTurnWithLinks(statusCtx, last.Portal, last.Meta, combinedBody, rawEventContent, last.Event.ID) if err != nil { oc.loggerForContext(ctx).Err(err).Msg("Failed to build prompt for debounced messages") oc.notifyMatrixSendFailure(statusCtx, last.Portal, last.Event, err) diff --git a/bridges/ai/connector.go b/bridges/ai/connector.go index ed9238bf..f6d822c7 100644 --- a/bridges/ai/connector.go +++ b/bridges/ai/connector.go @@ -55,10 +55,16 @@ func (oc *OpenAIConnector) applyRuntimeDefaults() { oc.Config.ModelCacheDuration = 6 * time.Hour } bridgesdk.ApplyDefaultCommandPrefix(&oc.Config.Bridge.CommandPrefix, "!ai") - if oc.Config.Pruning == nil { - oc.Config.Pruning = airuntime.DefaultPruningConfig() + if oc.Config.Agents == nil { + oc.Config.Agents = &AgentsConfig{} + } + if oc.Config.Agents.Defaults == nil { + oc.Config.Agents.Defaults = &AgentDefaultsConfig{} + } + if oc.Config.Agents.Defaults.Compaction == nil { + oc.Config.Agents.Defaults.Compaction = airuntime.DefaultPruningConfig() } else { - oc.Config.Pruning = airuntime.ApplyPruningDefaults(oc.Config.Pruning) + oc.Config.Agents.Defaults.Compaction = airuntime.ApplyPruningDefaults(oc.Config.Agents.Defaults.Compaction) } } diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 04a360fc..bc4a9a4c 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -258,7 +258,7 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri eventID = msg.Event.ID } - promptContext, err := oc.buildContextWithLinkContext(runCtx, portal, runMeta, body, rawEventContent, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(runCtx, portal, runMeta, body, rawEventContent, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the message. Try again.", "") } @@ -604,7 +604,7 @@ func (oc *AIClient) handleMediaMessage( inboundCtx := oc.buildMatrixInboundContext(portal, msg.Event, rawBody, senderName, roomName, isGroup) promptCtx := withInboundContext(ctx, inboundCtx) body := oc.buildMatrixInboundBody(ctx, portal, meta, msg.Event, rawBody, senderName, roomName, isGroup) - promptContext, err := oc.buildContextWithLinkContext(promptCtx, portal, meta, body, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(promptCtx, portal, meta, body, nil, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the message. Try again.", "") } @@ -713,7 +713,7 @@ func (oc *AIClient) handleMediaMessage( captionForPrompt := oc.buildMatrixInboundBody(ctx, portal, meta, msg.Event, caption, senderName, roomName, isGroup) captionInboundCtx := oc.buildMatrixInboundContext(portal, msg.Event, caption, senderName, roomName, isGroup) promptCtx := withInboundContext(ctx, captionInboundCtx) - promptContext, err := oc.buildContextWithMedia(promptCtx, portal, meta, captionForPrompt, string(mediaURL), mimeType, encryptedFile, config.msgType, eventID) + promptContext, err := oc.buildMediaTurnContext(promptCtx, portal, meta, captionForPrompt, string(mediaURL), mimeType, encryptedFile, config.msgType, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the media message. Try again.", "") } @@ -870,7 +870,7 @@ func (oc *AIClient) handleTextFileMessage( inboundCtx := oc.buildMatrixInboundContext(portal, msg.Event, combined, senderName, roomName, isGroup) promptCtx := withInboundContext(ctx, inboundCtx) - promptContext, err := oc.buildContextWithLinkContext(promptCtx, portal, meta, combined, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(promptCtx, portal, meta, combined, nil, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the message. Try again.", "") } diff --git a/bridges/ai/heartbeat_execute.go b/bridges/ai/heartbeat_execute.go index 27d9a57c..ae13e85f 100644 --- a/bridges/ai/heartbeat_execute.go +++ b/bridges/ai/heartbeat_execute.go @@ -179,7 +179,7 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, } } - promptContext, err := oc.buildContextWithHeartbeat(context.Background(), sessionPortal, promptMeta, prompt) + promptContext, err := oc.buildHeartbeatTurnContext(context.Background(), sessionPortal, promptMeta, prompt) if err != nil { oc.log.Warn().Str("agent_id", agentID).Str("reason", reason).Err(err).Msg("Heartbeat failed to build prompt") indicator := (*HeartbeatIndicatorType)(nil) @@ -282,21 +282,6 @@ func systemEventsOwnerKey(oc *AIClient) string { return string(oc.UserLogin.Bridge.DB.BridgeID) + "|" + string(oc.UserLogin.ID) } -func (oc *AIClient) buildContextWithHeartbeat(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, prompt string) (PromptContext, error) { - base, text, err := oc.buildCurrentTurnText(ctx, portal, meta, prompt, "", currentTurnTextOptions{}) - if err != nil { - return PromptContext{}, err - } - base.Messages = append(base.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: text, - }}, - }) - return base, nil -} - func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *HeartbeatConfig) (*bridgev2.Portal, string, error) { session := "" if heartbeat != nil && heartbeat.Session != nil { diff --git a/bridges/ai/integrations_config.go b/bridges/ai/integrations_config.go index ad7298b7..db359427 100644 --- a/bridges/ai/integrations_config.go +++ b/bridges/ai/integrations_config.go @@ -22,7 +22,6 @@ var exampleNetworkConfig string // the `network:` block in the main bridge config. type Config struct { Beeper BeeperConfig `yaml:"beeper"` - Providers ProvidersConfig `yaml:"providers"` Models *ModelsConfig `yaml:"models"` Bridge BridgeConfig `yaml:"bridge"` Tools ToolProvidersConfig `yaml:"tools"` @@ -38,12 +37,6 @@ type Config struct { DefaultSystemPrompt string `yaml:"default_system_prompt"` ModelCacheDuration time.Duration `yaml:"model_cache_duration"` - // Context pruning configuration - Pruning *airuntime.PruningConfig `yaml:"pruning"` - - // Link preview configuration - LinkPreviews *LinkPreviewConfig `yaml:"link_previews"` - // Inbound message processing configuration Inbound *InboundConfig `yaml:"inbound"` @@ -117,6 +110,12 @@ type AgentDefaultsConfig struct { SkipBootstrap bool `yaml:"skip_bootstrap"` BootstrapMaxChars int `yaml:"bootstrap_max_chars"` TimeoutSeconds int `yaml:"timeout_seconds"` + Model *ModelSelectionConfig `yaml:"model"` + ImageModel *ModelSelectionConfig `yaml:"image_model"` + ImageGeneration *ModelSelectionConfig `yaml:"image_generation_model"` + PDFModel *ModelSelectionConfig `yaml:"pdf_model"` + PDFEngine string `yaml:"pdf_engine"` + Compaction *airuntime.PruningConfig `yaml:"compaction"` SoulEvil *agents.SoulEvilConfig `yaml:"soul_evil"` Heartbeat *HeartbeatConfig `yaml:"heartbeat"` UserTimezone string `yaml:"user_timezone"` @@ -228,11 +227,16 @@ type SessionConfig struct { // ToolProvidersConfig configures external tool providers like search and fetch. type ToolProvidersConfig struct { - Search *SearchConfig `yaml:"search"` - Fetch *FetchConfig `yaml:"fetch"` - Media *MediaToolsConfig `yaml:"media"` - MCP *MCPToolsConfig `yaml:"mcp"` - VFS *VFSToolsConfig `yaml:"vfs"` + Web *WebToolsConfig `yaml:"web"` + Links *LinkPreviewConfig `yaml:"links"` + Media *MediaToolsConfig `yaml:"media"` + MCP *MCPToolsConfig `yaml:"mcp"` + VFS *VFSToolsConfig `yaml:"vfs"` +} + +type WebToolsConfig struct { + Search *SearchConfig `yaml:"search"` + Fetch *FetchConfig `yaml:"fetch"` } // MCPToolsConfig configures generic MCP behavior. @@ -422,21 +426,6 @@ type BeeperConfig struct { Token string `yaml:"token"` // Beeper Matrix access token } -// ProviderConfig holds settings for a specific AI provider. -type ProviderConfig struct { - APIKey string `yaml:"api_key"` - BaseURL string `yaml:"base_url"` - DefaultModel string `yaml:"default_model"` - DefaultPDFEngine string `yaml:"default_pdf_engine"` // pdf-text, mistral-ocr (default), native -} - -// ProvidersConfig contains per-provider configuration. -type ProvidersConfig struct { - Beeper ProviderConfig `yaml:"beeper"` - OpenAI ProviderConfig `yaml:"openai"` - OpenRouter ProviderConfig `yaml:"openrouter"` -} - // ModelsConfig configures model catalog seeding. type ModelsConfig struct { Mode string `yaml:"mode"` // merge | replace @@ -445,7 +434,22 @@ type ModelsConfig struct { // ModelProviderConfig describes models for a specific provider. type ModelProviderConfig struct { - Models []ModelDefinitionConfig `yaml:"models"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + Headers map[string]string `yaml:"headers"` + Models []ModelDefinitionConfig `yaml:"models"` +} + +type ModelSelectionConfig struct { + Primary string `yaml:"primary"` + Fallbacks []string `yaml:"fallbacks"` +} + +func (cfg *ModelsConfig) Provider(name string) ModelProviderConfig { + if cfg == nil || len(cfg.Providers) == 0 { + return ModelProviderConfig{} + } + return cfg.Providers[strings.ToLower(strings.TrimSpace(name))] } // ModelDefinitionConfig defines a model entry for catalog seeding. @@ -467,16 +471,18 @@ func upgradeConfig(helper configupgrade.Helper) { helper.Copy(configupgrade.Str, "beeper", "base_url") helper.Copy(configupgrade.Str, "beeper", "token") - // Per-provider default models - helper.Copy(configupgrade.Str, "providers", "beeper", "default_model") - helper.Copy(configupgrade.Str, "providers", "beeper", "default_pdf_engine") - helper.Copy(configupgrade.Str, "providers", "openai", "api_key") - helper.Copy(configupgrade.Str, "providers", "openai", "base_url") - helper.Copy(configupgrade.Str, "providers", "openai", "default_model") - helper.Copy(configupgrade.Str, "providers", "openrouter", "api_key") - helper.Copy(configupgrade.Str, "providers", "openrouter", "base_url") - helper.Copy(configupgrade.Str, "providers", "openrouter", "default_model") - helper.Copy(configupgrade.Str, "providers", "openrouter", "default_pdf_engine") + // Model providers and defaults + helper.Copy(configupgrade.Str, "models", "mode") + helper.Copy(configupgrade.Map, "models", "providers") + helper.Copy(configupgrade.Str, "agents", "defaults", "model", "primary") + helper.Copy(configupgrade.List, "agents", "defaults", "model", "fallbacks") + helper.Copy(configupgrade.Str, "agents", "defaults", "image_model", "primary") + helper.Copy(configupgrade.List, "agents", "defaults", "image_model", "fallbacks") + helper.Copy(configupgrade.Str, "agents", "defaults", "image_generation_model", "primary") + helper.Copy(configupgrade.List, "agents", "defaults", "image_generation_model", "fallbacks") + helper.Copy(configupgrade.Str, "agents", "defaults", "pdf_model", "primary") + helper.Copy(configupgrade.List, "agents", "defaults", "pdf_model", "fallbacks") + helper.Copy(configupgrade.Str, "agents", "defaults", "pdf_engine") // Global settings helper.Copy(configupgrade.Str, "default_system_prompt") @@ -493,47 +499,35 @@ func upgradeConfig(helper configupgrade.Helper) { // Bridge-specific configuration helper.Copy(configupgrade.Str, "bridge", "command_prefix") - // Context pruning configuration - helper.Copy(configupgrade.Str, "pruning", "mode") - helper.Copy(configupgrade.Str, "pruning", "ttl") - helper.Copy(configupgrade.Bool, "pruning", "enabled") - helper.Copy(configupgrade.Float, "pruning", "soft_trim_ratio") - helper.Copy(configupgrade.Float, "pruning", "hard_clear_ratio") - helper.Copy(configupgrade.Int, "pruning", "keep_last_assistants") - helper.Copy(configupgrade.Int, "pruning", "min_prunable_chars") - helper.Copy(configupgrade.Int, "pruning", "soft_trim_max_chars") - helper.Copy(configupgrade.Int, "pruning", "soft_trim_head_chars") - helper.Copy(configupgrade.Int, "pruning", "soft_trim_tail_chars") - helper.Copy(configupgrade.Bool, "pruning", "hard_clear_enabled") - helper.Copy(configupgrade.Str, "pruning", "hard_clear_placeholder") - - // Compaction configuration (LLM summarization) - helper.Copy(configupgrade.Bool, "pruning", "summarization_enabled") - helper.Copy(configupgrade.Str, "pruning", "summarization_model") - helper.Copy(configupgrade.Int, "pruning", "max_summary_tokens") - helper.Copy(configupgrade.Str, "pruning", "compaction_mode") - helper.Copy(configupgrade.Int, "pruning", "keep_recent_tokens") - helper.Copy(configupgrade.Float, "pruning", "max_history_share") - helper.Copy(configupgrade.Int, "pruning", "reserve_tokens") - helper.Copy(configupgrade.Int, "pruning", "reserve_tokens_floor") - helper.Copy(configupgrade.Str, "pruning", "custom_instructions") - helper.Copy(configupgrade.Str, "pruning", "identifier_policy") - helper.Copy(configupgrade.Str, "pruning", "identifier_instructions") - helper.Copy(configupgrade.Str, "pruning", "post_compaction_refresh_prompt") - helper.Copy(configupgrade.Bool, "pruning", "overflow_flush", "enabled") - helper.Copy(configupgrade.Int, "pruning", "overflow_flush", "soft_threshold_tokens") - helper.Copy(configupgrade.Str, "pruning", "overflow_flush", "prompt") - helper.Copy(configupgrade.Str, "pruning", "overflow_flush", "system_prompt") - - // Link preview configuration - helper.Copy(configupgrade.Bool, "link_previews", "enabled") - helper.Copy(configupgrade.Int, "link_previews", "max_urls_inbound") - helper.Copy(configupgrade.Int, "link_previews", "max_urls_outbound") - helper.Copy(configupgrade.Str, "link_previews", "fetch_timeout") - helper.Copy(configupgrade.Int, "link_previews", "max_content_chars") - helper.Copy(configupgrade.Int, "link_previews", "max_page_bytes") - helper.Copy(configupgrade.Int, "link_previews", "max_image_bytes") - helper.Copy(configupgrade.Str, "link_previews", "cache_ttl") + // Compaction configuration + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "mode") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "ttl") + helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "enabled") + helper.Copy(configupgrade.Float, "agents", "defaults", "compaction", "soft_trim_ratio") + helper.Copy(configupgrade.Float, "agents", "defaults", "compaction", "hard_clear_ratio") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "keep_last_assistants") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "min_prunable_chars") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "soft_trim_max_chars") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "soft_trim_head_chars") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "soft_trim_tail_chars") + helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "hard_clear_enabled") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "hard_clear_placeholder") + helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "summarization_enabled") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "summarization_model") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "max_summary_tokens") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "compaction_mode") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "keep_recent_tokens") + helper.Copy(configupgrade.Float, "agents", "defaults", "compaction", "max_history_share") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "reserve_tokens") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "reserve_tokens_floor") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "custom_instructions") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "identifier_policy") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "identifier_instructions") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "post_compaction_refresh_prompt") + helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "overflow_flush", "enabled") + helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "overflow_flush", "soft_threshold_tokens") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "overflow_flush", "prompt") + helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "overflow_flush", "system_prompt") // Inbound message processing configuration helper.Copy(configupgrade.Str, "inbound", "dedupe_ttl") @@ -597,32 +591,40 @@ func upgradeConfig(helper configupgrade.Helper) { helper.Copy(configupgrade.Str, "channels", "matrix", "reply_to_mode") helper.Copy(configupgrade.Str, "channels", "matrix", "thread_replies") - // Tools (search + fetch) - helper.Copy(configupgrade.Str, "tools", "search", "provider") - helper.Copy(configupgrade.List, "tools", "search", "fallbacks") - helper.Copy(configupgrade.Bool, "tools", "search", "exa", "enabled") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "base_url") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "api_key") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "type") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "category") - helper.Copy(configupgrade.Int, "tools", "search", "exa", "num_results") - helper.Copy(configupgrade.Bool, "tools", "search", "exa", "include_text") - helper.Copy(configupgrade.Int, "tools", "search", "exa", "text_max_chars") - helper.Copy(configupgrade.Bool, "tools", "search", "exa", "highlights") - helper.Copy(configupgrade.Str, "tools", "fetch", "provider") - helper.Copy(configupgrade.List, "tools", "fetch", "fallbacks") - helper.Copy(configupgrade.Bool, "tools", "fetch", "exa", "enabled") - helper.Copy(configupgrade.Str, "tools", "fetch", "exa", "base_url") - helper.Copy(configupgrade.Str, "tools", "fetch", "exa", "api_key") - helper.Copy(configupgrade.Bool, "tools", "fetch", "exa", "include_text") - helper.Copy(configupgrade.Int, "tools", "fetch", "exa", "text_max_chars") - helper.Copy(configupgrade.Bool, "tools", "fetch", "direct", "enabled") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "timeout_seconds") - helper.Copy(configupgrade.Str, "tools", "fetch", "direct", "user_agent") - helper.Copy(configupgrade.Bool, "tools", "fetch", "direct", "readability") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "max_chars") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "max_redirects") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "cache_ttl_seconds") + // Tools (web + links) + helper.Copy(configupgrade.Str, "tools", "web", "search", "provider") + helper.Copy(configupgrade.List, "tools", "web", "search", "fallbacks") + helper.Copy(configupgrade.Bool, "tools", "web", "search", "exa", "enabled") + helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "base_url") + helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "api_key") + helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "type") + helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "category") + helper.Copy(configupgrade.Int, "tools", "web", "search", "exa", "num_results") + helper.Copy(configupgrade.Bool, "tools", "web", "search", "exa", "include_text") + helper.Copy(configupgrade.Int, "tools", "web", "search", "exa", "text_max_chars") + helper.Copy(configupgrade.Bool, "tools", "web", "search", "exa", "highlights") + helper.Copy(configupgrade.Str, "tools", "web", "fetch", "provider") + helper.Copy(configupgrade.List, "tools", "web", "fetch", "fallbacks") + helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "exa", "enabled") + helper.Copy(configupgrade.Str, "tools", "web", "fetch", "exa", "base_url") + helper.Copy(configupgrade.Str, "tools", "web", "fetch", "exa", "api_key") + helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "exa", "include_text") + helper.Copy(configupgrade.Int, "tools", "web", "fetch", "exa", "text_max_chars") + helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "direct", "enabled") + helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "timeout_seconds") + helper.Copy(configupgrade.Str, "tools", "web", "fetch", "direct", "user_agent") + helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "direct", "readability") + helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "max_chars") + helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "max_redirects") + helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "cache_ttl_seconds") + helper.Copy(configupgrade.Bool, "tools", "links", "enabled") + helper.Copy(configupgrade.Int, "tools", "links", "max_urls_inbound") + helper.Copy(configupgrade.Int, "tools", "links", "max_urls_outbound") + helper.Copy(configupgrade.Str, "tools", "links", "fetch_timeout") + helper.Copy(configupgrade.Int, "tools", "links", "max_content_chars") + helper.Copy(configupgrade.Int, "tools", "links", "max_page_bytes") + helper.Copy(configupgrade.Int, "tools", "links", "max_image_bytes") + helper.Copy(configupgrade.Str, "tools", "links", "cache_ttl") helper.Copy(configupgrade.Bool, "tools", "mcp", "enable_stdio") helper.Copy(configupgrade.Int, "tools", "media", "image", "max_bytes") helper.Copy(configupgrade.Int, "tools", "media", "image", "max_chars") diff --git a/bridges/ai/integrations_example-config.yaml b/bridges/ai/integrations_example-config.yaml index cab66e41..f51aa046 100644 --- a/bridges/ai/integrations_example-config.yaml +++ b/bridges/ai/integrations_example-config.yaml @@ -1,157 +1,101 @@ # Connector-specific configuration lives under the `network:` section of the # main config file. -# Beeper Cloud credentials for automatic login (optional). -# If user_mxid, base_url, and token are set, users don't need to manually log in. beeper: - user_mxid: "" # Owning Matrix user for the built-in Beeper Cloud login. - base_url: "" # Optional. If empty, login uses selected Beeper domain. - token: "" # Beeper Matrix access token + user_mxid: "" + base_url: "" + token: "" -# Per-provider default models and settings. -# These are used when a room doesn't have a specific model configured. -providers: - beeper: - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - openai: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://api.openai.com/v1 - base_url: "https://api.openai.com/v1" - default_model: "openai/gpt-5.4" - openrouter: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://openrouter.ai/api/v1 - base_url: "https://openrouter.ai/api/v1" - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - -# Optional model catalog seeding. -# models: -# mode: "merge" # merge | replace -# providers: -# openai: -# models: -# - id: "gpt-5.2" -# name: "GPT-5.2" -# reasoning: true -# input: ["text", "image"] -# context_window: 128000 -# max_tokens: 8192 +models: + providers: + openai: + api_key: "" + base_url: "https://api.openai.com/v1" + models: [] + openrouter: + api_key: "" + base_url: "https://openrouter.ai/api/v1" + models: [] + magic_proxy: + api_key: "" + base_url: "" + models: [] -# Global settings default_system_prompt: | You are a helpful, concise assistant. Ask clarifying questions when needed. Follow the user's intent and be accurate. model_cache_duration: 6h -# Optional message rendering settings. messages: - # History defaults for prompt construction. - # Set 0 to disable. direct_chat: history_limit: 20 group_chat: history_limit: 50 - # Queue behavior while the agent is busy. queue: - # Modes: collect, followup, steer, steer-backlog, interrupt mode: "collect" - # Debounce time before draining queued messages (ms). debounce_ms: 1000 - # Maximum queued messages before drop policy applies. cap: 20 - # Drop policy when cap is exceeded: summarize, old, new drop: "summarize" -# Command authorization settings. commands: - # Optional allowlist for owner-only tools/commands (Matrix IDs, or "matrix:@user:server"). owner_allow_from: [] -# Tool approval gating. tool_approvals: enabled: true ttl_seconds: 600 require_for_mcp: true - # List of builtin tool names that require approval (subject to per-tool action allowlists). - # Note: `message` approvals apply to Desktop API routing too (e.g. action=send/reply/edit with desktop chat hints), - # while Desktop read-only actions like desktop-search-* do not require approval. require_for_tools: ["message", "cron", "gravatar_set", "create_agent", "fork_agent", "edit_agent", "delete_agent", "modify_room", "sessions_send", "sessions_spawn", "run_internal_command"] - # Fallback when approval times out: "deny" (default) | "allow". - # Set to "allow" for cron/automated contexts where no human can respond. -# Optional per-channel overrides. channels: matrix: - # Matrix reply/thread behavior. reply_to_mode: "first" -# Session configuration. session: - # Scope for session state: per-sender (default) or global. scope: "per-sender" - # Main session key alias (default: "main"). main_key: "main" -# External tool providers (search + fetch). Proxy is optional. tools: - search: - provider: "openrouter" - fallbacks: ["exa", "brave", "perplexity"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - type: "auto" - num_results: 5 - include_text: false - text_max_chars: 500 - highlights: true # enabled by default; provides description snippets for source cards - brave: - api_key: "" - base_url: "https://api.search.brave.com/res/v1/web/search" - perplexity: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "perplexity/sonar-pro" - openrouter: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "openai/gpt-5.4" - fetch: - provider: "exa" - fallbacks: ["direct"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - include_text: true - text_max_chars: 5000 - direct: - enabled: true - timeout_seconds: 30 - max_chars: 50000 - max_redirects: 3 - - # Generic MCP behavior. + web: + search: + provider: "exa" + fallbacks: [] + exa: + api_key: "" + base_url: "https://api.exa.ai" + type: "auto" + num_results: 5 + include_text: false + text_max_chars: 500 + highlights: true + fetch: + provider: "direct" + fallbacks: ["exa"] + exa: + api_key: "" + base_url: "https://api.exa.ai" + include_text: true + text_max_chars: 5000 + direct: + enabled: true + timeout_seconds: 30 + max_chars: 50000 + max_redirects: 3 + links: + enabled: true + max_urls_inbound: 3 + max_urls_outbound: 5 + fetch_timeout: 10s + max_content_chars: 500 + max_page_bytes: 10485760 + max_image_bytes: 5242880 + cache_ttl: 1h mcp: - # Disabled by default for safety. Enable explicitly to allow local stdio MCP servers. enable_stdio: false - - # Virtual filesystem tools. vfs: apply_patch: enabled: false allow_models: [] - - # Media understanding/transcription. - # Supports provider/CLI entries and per-capability defaults. media: concurrency: 2 image: @@ -167,7 +111,6 @@ tools: enabled: true prompt: "Transcribe the audio." language: "" - # CLI transcription auto-detection (whisper/whisper.cpp) is not implemented yet. max_bytes: 20971520 timeout_seconds: 60 models: @@ -200,165 +143,51 @@ tools: candidate_multiplier: 4 cache: enabled: true - max_entries: -1 # Unlimited when cache.enabled is true; experimental.session_memory does not change cache size semantics. + max_entries: -1 experimental: session_memory: false -# Recall configuration. -# recall: -# citations: "auto" # auto | on | off -# inject_context: false # default false. when true, injects MEMORY.md snippets as extra system context. - - # Tool policy. Controls allow/deny lists and profiles. - # tool_policy: - # profile: "full" - # # group:openclaw is the strict OpenClaw native tool set. - # # group:ai-bridge includes ai-bridge-only extras (beeper_docs, gravatar_*, tts, image_generate, calculator, etc). - # allow: ["group:openclaw", "group:ai-bridge"] - # deny: [] - # subagents: - # tools: - # deny: ["sessions_list", "sessions_history", "sessions_send"] - - # Agent defaults. - # agents: - # defaults: - # subagents: - # model: "anthropic/claude-sonnet-4.5" - # allow_agents: ["*"] - # skip_bootstrap: false - # bootstrap_max_chars: 20000 - # typing_mode: "instant" # never|instant|thinking|message (message ignores NO_REPLY) - # typing_interval_seconds: 6 # refresh cadence, not start time (heartbeats never show typing) - # soul_evil: - # file: "SOUL_EVIL.md" - # chance: 0.1 - # purge: - # at: "21:00" - # duration: "15m" - -# Context pruning configuration. -# Reduces token usage by intelligently truncating old tool results. -pruning: - # Pruning mode: off | cache-ttl - # cache-ttl is the default pruning mode. - mode: "cache-ttl" - - # Refresh interval for cache-ttl mode. - ttl: "1h" - - # Enable proactive context pruning - enabled: true - - # Ratio of context window usage that triggers soft trimming (0.0-1.0) - # At 30% usage, large tool results start getting truncated - soft_trim_ratio: 0.3 - - # Ratio of context window usage that triggers hard clearing (0.0-1.0) - # At 50% usage, old tool results are replaced with placeholder - hard_clear_ratio: 0.5 - - # Number of recent assistant messages to protect from pruning - keep_last_assistants: 3 - - # Minimum total chars in prunable tool results before hard clear kicks in - min_prunable_chars: 50000 - - # Tool results larger than this are candidates for soft trimming - soft_trim_max_chars: 4000 - - # When soft trimming, keep this many chars from the start - soft_trim_head_chars: 1500 - - # When soft trimming, keep this many chars from the end - soft_trim_tail_chars: 1500 - - # Enable/disable hard clear phase - hard_clear_enabled: true - - # Placeholder text for hard-cleared tool results - hard_clear_placeholder: "[Old tool result content cleared]" - - # Tool patterns to allow/deny pruning (supports wildcards: list_*, *_search) - # Empty means all tools are prunable unless denied - # tools_allow: [] - # tools_deny: [] - - # --- LLM-based summarization (compaction) --- - # When enabled, uses an LLM to generate intelligent summaries of compacted - # content instead of just using placeholder text. This preserves context better. - - # Enable LLM summarization (default: true when pruning is enabled) - summarization_enabled: true - - # Model to use for generating summaries (default: fast model) - summarization_model: "openai/gpt-5.4" - - # Maximum tokens for generated summaries - max_summary_tokens: 500 - - # Compaction mode: - # - default: balanced reduction - # - safeguard: preserves recent context more aggressively - compaction_mode: "safeguard" - - # Minimum recent token budget preserved during safeguard compaction - keep_recent_tokens: 20000 - - # Maximum ratio of context that history can consume (0.0-1.0) - # When exceeded, oldest messages are summarized to fit budget - max_history_share: 0.5 - - # Token budget reserved for compaction output - reserve_tokens: 20000 - # Floor applied to reserve_tokens to avoid aggressive overfill - reserve_tokens_floor: 20000 - - # Optional post-compaction system context injected before retry - post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." - - # Additional instructions for the summarization model - # custom_instructions: "Focus on preserving code decisions and TODOs" - - # Identifier preservation policy for summaries: - # - strict (default): preserve opaque identifiers exactly - # - off: no special identifier-preservation instruction - # - custom: use identifier_instructions below - identifier_policy: "strict" - # identifier_instructions: "Keep ticket IDs, hashes, and hostnames unchanged." - - # Optional pre-compaction overflow flush turn. - # Enabled by default. Disable explicitly if you want no pre-flush. - overflow_flush: - enabled: true - soft_threshold_tokens: 4000 - prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." - system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - -# Link preview configuration. -# Automatically fetches metadata for URLs in messages to provide context to the AI -# and generate rich previews in outgoing AI responses. -link_previews: - # Enable link preview functionality (default: true) - enabled: true - - # Maximum number of URLs to fetch from user messages for AI context (default: 3) - max_urls_inbound: 3 - - # Maximum number of URLs to preview in AI responses (default: 5) - max_urls_outbound: 5 - - # Timeout for fetching each URL (default: 10s) - fetch_timeout: 10s - - # Maximum characters from description to include in context (default: 500) - max_content_chars: 500 - - # Maximum page size to download in bytes (default: 10MB) - max_page_bytes: 10485760 - - # Maximum image size to download in bytes (default: 5MB) - max_image_bytes: 5242880 +agents: + defaults: + model: + primary: "" + fallbacks: [] + image_model: + primary: "" + fallbacks: [] + image_generation_model: + primary: "" + fallbacks: [] + pdf_model: + primary: "" + fallbacks: [] + pdf_engine: "mistral-ocr" + compaction: + mode: "cache-ttl" + ttl: "1h" + enabled: true + soft_trim_ratio: 0.3 + hard_clear_ratio: 0.5 + keep_last_assistants: 3 + min_prunable_chars: 50000 + soft_trim_max_chars: 4000 + soft_trim_head_chars: 1500 + soft_trim_tail_chars: 1500 + hard_clear_enabled: true + hard_clear_placeholder: "[Old tool result content cleared]" + summarization_enabled: true + summarization_model: "openai/gpt-5.2" + max_summary_tokens: 500 + compaction_mode: "safeguard" + keep_recent_tokens: 20000 + max_history_share: 0.5 + reserve_tokens: 20000 + reserve_tokens_floor: 20000 + identifier_policy: "strict" + post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." + overflow_flush: + enabled: true + soft_threshold_tokens: 4000 + prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." + system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - # How long to cache URL previews (default: 1h) - cache_ttl: 1h diff --git a/bridges/ai/internal_dispatch.go b/bridges/ai/internal_dispatch.go index bfc3913d..d4cfd657 100644 --- a/bridges/ai/internal_dispatch.go +++ b/bridges/ai/internal_dispatch.go @@ -44,7 +44,7 @@ func (oc *AIClient) dispatchInternalMessage( inboundCtx := oc.resolvePromptInboundContext(ctx, portal, trimmed, eventID) promptCtx := withInboundContext(ctx, inboundCtx) - promptContext, err := oc.buildContextWithLinkContext(promptCtx, portal, meta, trimmed, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(promptCtx, portal, meta, trimmed, nil, eventID) if err != nil { return eventID, false, err } diff --git a/bridges/ai/login.go b/bridges/ai/login.go index 00532fbd..5f6f2b87 100644 --- a/bridges/ai/login.go +++ b/bridges/ai/login.go @@ -369,8 +369,8 @@ func (ol *OpenAILogin) resolveCustomLogin(input map[string]string) (string, stri if input == nil { input = map[string]string{} } - openrouterCfg := strings.TrimSpace(ol.Connector.Config.Providers.OpenRouter.APIKey) - openaiCfg := strings.TrimSpace(ol.Connector.Config.Providers.OpenAI.APIKey) + openrouterCfg := strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenRouter).APIKey) + openaiCfg := strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenAI).APIKey) openrouterInput := "" openaiInput := "" @@ -478,18 +478,22 @@ func parseMagicProxyLink(raw string) (string, string, error) { } func (ol *OpenAILogin) configHasOpenRouterKey() bool { - return strings.TrimSpace(ol.Connector.Config.Providers.OpenRouter.APIKey) != "" + return strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenRouter).APIKey) != "" } func (ol *OpenAILogin) configHasOpenAIKey() bool { - return strings.TrimSpace(ol.Connector.Config.Providers.OpenAI.APIKey) != "" + return strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenAI).APIKey) != "" } func (ol *OpenAILogin) configHasExaKey() bool { - if ol.Connector.Config.Tools.Search != nil && strings.TrimSpace(ol.Connector.Config.Tools.Search.Exa.APIKey) != "" { + if ol.Connector.Config.Tools.Web != nil && + ol.Connector.Config.Tools.Web.Search != nil && + strings.TrimSpace(ol.Connector.Config.Tools.Web.Search.Exa.APIKey) != "" { return true } - if ol.Connector.Config.Tools.Fetch != nil && strings.TrimSpace(ol.Connector.Config.Tools.Fetch.Exa.APIKey) != "" { + if ol.Connector.Config.Tools.Web != nil && + ol.Connector.Config.Tools.Web.Fetch != nil && + strings.TrimSpace(ol.Connector.Config.Tools.Web.Fetch.Exa.APIKey) != "" { return true } return false diff --git a/bridges/ai/media_understanding_runner.go b/bridges/ai/media_understanding_runner.go index 49962119..fbce68a5 100644 --- a/bridges/ai/media_understanding_runner.go +++ b/bridges/ai/media_understanding_runner.go @@ -921,10 +921,7 @@ func (oc *AIClient) resolveOpenRouterMediaConfig( if baseURL == "" { baseURL = resolveOpenRouterMediaBaseURL(oc) } - pdfEngine = oc.connector.Config.Providers.OpenRouter.DefaultPDFEngine - if pdfEngine == "" { - pdfEngine = "mistral-ocr" - } + pdfEngine = oc.defaultPDFEngine() if oc.UserLogin != nil && oc.UserLogin.User != nil && oc.UserLogin.User.MXID != "" { userID = oc.UserLogin.User.MXID.String() } diff --git a/bridges/ai/media_understanding_runner_openai_test.go b/bridges/ai/media_understanding_runner_openai_test.go index 709a71ea..745307a4 100644 --- a/bridges/ai/media_understanding_runner_openai_test.go +++ b/bridges/ai/media_understanding_runner_openai_test.go @@ -53,13 +53,7 @@ func TestResolveOpenRouterMediaConfigUsesEntryOverrides(t *testing.T) { t.Setenv("OPENROUTER_API_KEY_SPECIAL_PROFILE", "entry-key") client := newMediaTestClient(&UserLoginMetadata{Provider: ProviderOpenAI}, &OpenAIConnector{ - Config: Config{ - Providers: ProvidersConfig{ - OpenRouter: ProviderConfig{ - DefaultPDFEngine: "native", - }, - }, - }, + Config: Config{}, }) cfg := &MediaUnderstandingConfig{ @@ -99,8 +93,8 @@ func TestResolveOpenRouterMediaConfigUsesEntryOverrides(t *testing.T) { if headers["X-Title"] != openRouterAppTitle { t.Fatalf("expected default OpenRouter title header, got %#v", headers) } - if pdfEngine != "native" { - t.Fatalf("expected configured PDF engine, got %q", pdfEngine) + if pdfEngine != "mistral-ocr" { + t.Fatalf("expected default PDF engine, got %q", pdfEngine) } } diff --git a/bridges/ai/prompt_builder.go b/bridges/ai/prompt_builder.go index 0c704d70..91d7ab13 100644 --- a/bridges/ai/prompt_builder.go +++ b/bridges/ai/prompt_builder.go @@ -2,11 +2,13 @@ package ai import ( "context" + "fmt" "strings" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -30,6 +32,22 @@ type currentTurnTextOptions struct { append []string } +type turnAttachmentOptions struct { + mediaURL string + mimeType string + encryptedFile *event.EncryptedFileInfo + mediaType pendingMessageType +} + +type currentTurnPromptOptions struct { + rawEventContent map[string]any + includeLinkScope bool + prepend []string + append []string + leadingBlocks []PromptBlock + attachment *turnAttachmentOptions +} + func joinPromptFragments(parts ...string) string { var filtered []string for _, part := range parts { @@ -176,3 +194,96 @@ func (oc *AIClient) buildCurrentTurnText( body := joinPromptFragments(append(append(prepend, result.ResolvedBody), appendParts...)...) return result.PromptContext, body, nil } + +func (oc *AIClient) buildPromptContextForTurn( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + userText string, + eventID id.EventID, + opts currentTurnPromptOptions, +) (PromptContext, error) { + appendFragments := append([]string{}, opts.append...) + leadingBlocks := append([]PromptBlock{}, opts.leadingBlocks...) + + if opts.attachment != nil { + attachmentBlocks, attachmentAppend, err := oc.normalizeTurnAttachment(ctx, *opts.attachment) + if err != nil { + return PromptContext{}, err + } + leadingBlocks = append(leadingBlocks, attachmentBlocks...) + appendFragments = append(appendFragments, attachmentAppend...) + } + + base, text, err := oc.buildCurrentTurnText(ctx, portal, meta, userText, eventID, currentTurnTextOptions{ + rawEventContent: opts.rawEventContent, + includeLinkScope: opts.includeLinkScope, + prepend: opts.prepend, + append: appendFragments, + }) + if err != nil { + return PromptContext{}, err + } + + blocks := make([]PromptBlock, 0, len(leadingBlocks)+1) + if strings.TrimSpace(text) != "" { + blocks = append(blocks, PromptBlock{Type: PromptBlockText, Text: text}) + } + blocks = append(blocks, leadingBlocks...) + base.Messages = append(base.Messages, PromptMessage{ + Role: PromptRoleUser, + Blocks: blocks, + }) + return base, nil +} + +func (oc *AIClient) normalizeTurnAttachment(ctx context.Context, opts turnAttachmentOptions) ([]PromptBlock, []string, error) { + switch opts.mediaType { + case pendingTypeImage: + b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, opts.mediaURL, opts.encryptedFile, 20, opts.mimeType) + if err != nil { + return nil, nil, fmt.Errorf("failed to download image: %w", err) + } + return []PromptBlock{{ + Type: PromptBlockImage, + ImageB64: b64Data, + MimeType: actualMimeType, + }}, nil, nil + case pendingTypePDF: + content, truncated, err := oc.downloadPDFFile(ctx, opts.mediaURL, opts.encryptedFile, opts.mimeType) + if err != nil { + return nil, nil, fmt.Errorf("failed to download PDF: %w", err) + } + filename := resolveMediaFileName("document.pdf", "pdf", opts.mediaURL) + return nil, []string{buildTextFileMessage("", false, filename, "application/pdf", content, truncated)}, nil + case pendingTypeAudio: + return nil, nil, fmt.Errorf("audio attachments must be preprocessed into text before prompt assembly") + case pendingTypeVideo: + return nil, nil, fmt.Errorf("video attachments must be preprocessed into text before prompt assembly") + default: + return nil, nil, fmt.Errorf("unsupported media type: %s", opts.mediaType) + } +} + +func (oc *AIClient) buildCurrentTurnWithLinks( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + userText string, + rawEventContent map[string]any, + eventID id.EventID, +) (PromptContext, error) { + return oc.buildPromptContextForTurn(ctx, portal, meta, userText, eventID, currentTurnPromptOptions{ + rawEventContent: rawEventContent, + includeLinkScope: true, + }) +} + +func (oc *AIClient) buildHeartbeatTurnContext( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + prompt string, +) (PromptContext, error) { + return oc.buildPromptContextForTurn(ctx, portal, meta, prompt, "", currentTurnPromptOptions{}) +} diff --git a/bridges/ai/response_retry_test.go b/bridges/ai/response_retry_test.go index 080d807b..a6827ed2 100644 --- a/bridges/ai/response_retry_test.go +++ b/bridges/ai/response_retry_test.go @@ -23,7 +23,9 @@ func newPruningTestClient(pruning *airuntime.PruningConfig, provider string) *AI }, connector: &OpenAIConnector{ Config: Config{ - Pruning: pruning, + Agents: &AgentsConfig{Defaults: &AgentDefaultsConfig{ + Compaction: pruning, + }}, }, }, log: zerolog.Nop(), diff --git a/bridges/ai/runtime_compaction_adapter.go b/bridges/ai/runtime_compaction_adapter.go index 2ce1362b..33fc2ccc 100644 --- a/bridges/ai/runtime_compaction_adapter.go +++ b/bridges/ai/runtime_compaction_adapter.go @@ -30,8 +30,10 @@ type CompactionEvent struct { } func (oc *AIClient) pruningConfigOrDefault() *airuntime.PruningConfig { - if oc != nil && oc.connector != nil && oc.connector.Config.Pruning != nil { - return airuntime.ApplyPruningDefaults(oc.connector.Config.Pruning) + if oc != nil && oc.connector != nil && oc.connector.Config.Agents != nil && + oc.connector.Config.Agents.Defaults != nil && + oc.connector.Config.Agents.Defaults.Compaction != nil { + return airuntime.ApplyPruningDefaults(oc.connector.Config.Agents.Defaults.Compaction) } return airuntime.DefaultPruningConfig() } diff --git a/bridges/ai/runtime_defaults_test.go b/bridges/ai/runtime_defaults_test.go index a06e32e9..a6acc25e 100644 --- a/bridges/ai/runtime_defaults_test.go +++ b/bridges/ai/runtime_defaults_test.go @@ -18,42 +18,44 @@ func TestApplyRuntimeDefaultsSetsPruningDefaults(t *testing.T) { if connector.Config.Bridge.CommandPrefix != "!ai" { t.Fatalf("expected command prefix !ai, got %q", connector.Config.Bridge.CommandPrefix) } - if connector.Config.Pruning == nil { - t.Fatal("expected pruning defaults to be initialized") + if connector.Config.Agents == nil || connector.Config.Agents.Defaults == nil || connector.Config.Agents.Defaults.Compaction == nil { + t.Fatal("expected compaction defaults to be initialized") } - if !connector.Config.Pruning.Enabled { + if !connector.Config.Agents.Defaults.Compaction.Enabled { t.Fatal("expected pruning defaults enabled") } - if connector.Config.Pruning.Mode != "cache-ttl" { - t.Fatalf("expected pruning mode cache-ttl, got %q", connector.Config.Pruning.Mode) + if connector.Config.Agents.Defaults.Compaction.Mode != "cache-ttl" { + t.Fatalf("expected pruning mode cache-ttl, got %q", connector.Config.Agents.Defaults.Compaction.Mode) } - if connector.Config.Pruning.TTL != time.Hour { - t.Fatalf("expected pruning ttl 1h, got %v", connector.Config.Pruning.TTL) + if connector.Config.Agents.Defaults.Compaction.TTL != time.Hour { + t.Fatalf("expected pruning ttl 1h, got %v", connector.Config.Agents.Defaults.Compaction.TTL) } } func TestApplyRuntimeDefaultsKeepsExplicitPruningModeOff(t *testing.T) { connector := &OpenAIConnector{ Config: Config{ - Pruning: &airuntime.PruningConfig{ - Mode: "off", - Enabled: false, - }, + Agents: &AgentsConfig{Defaults: &AgentDefaultsConfig{ + Compaction: &airuntime.PruningConfig{ + Mode: "off", + Enabled: false, + }, + }}, }, } connector.applyRuntimeDefaults() - if connector.Config.Pruning == nil { + if connector.Config.Agents == nil || connector.Config.Agents.Defaults == nil || connector.Config.Agents.Defaults.Compaction == nil { t.Fatal("expected pruning config to remain set") } - if connector.Config.Pruning.Mode != "off" { - t.Fatalf("expected pruning mode off to be preserved, got %q", connector.Config.Pruning.Mode) + if connector.Config.Agents.Defaults.Compaction.Mode != "off" { + t.Fatalf("expected pruning mode off to be preserved, got %q", connector.Config.Agents.Defaults.Compaction.Mode) } - if connector.Config.Pruning.Enabled { + if connector.Config.Agents.Defaults.Compaction.Enabled { t.Fatal("expected pruning enabled=false to be preserved") } - if connector.Config.Pruning.SoftTrimRatio <= 0 { + if connector.Config.Agents.Defaults.Compaction.SoftTrimRatio <= 0 { t.Fatal("expected missing pruning numeric defaults to be filled") } } diff --git a/bridges/ai/status_text.go b/bridges/ai/status_text.go index 10f00307..0ab7b12b 100644 --- a/bridges/ai/status_text.go +++ b/bridges/ai/status_text.go @@ -224,11 +224,11 @@ func (oc *AIClient) estimatePromptTokens(ctx context.Context, portal *bridgev2.P if oc == nil || portal == nil { return 0 } - prompt, err := oc.buildBasePrompt(ctx, portal, meta) + promptContext, err := oc.buildBaseContext(ctx, portal, meta) if err != nil { return 0 } - prompt = oc.augmentPromptWithIntegrations(ctx, portal, meta, prompt) + prompt := oc.promptContextToDispatchMessages(ctx, portal, meta, promptContext) modelID := oc.effectiveModel(meta) count, err := EstimateTokens(prompt, modelID) if err != nil { diff --git a/bridges/ai/streaming_request_tools_test.go b/bridges/ai/streaming_request_tools_test.go index 02b7eba3..f067123f 100644 --- a/bridges/ai/streaming_request_tools_test.go +++ b/bridges/ai/streaming_request_tools_test.go @@ -14,9 +14,9 @@ func testToolSelectionClient(supportsToolCalling bool) *AIClient { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{ + Web: &WebToolsConfig{Search: &SearchConfig{ Exa: ProviderExaConfig{APIKey: "test"}, - }, + }}, }, }, }, diff --git a/bridges/ai/streaming_tool_selection_test.go b/bridges/ai/streaming_tool_selection_test.go index 054a6345..9488e415 100644 --- a/bridges/ai/streaming_tool_selection_test.go +++ b/bridges/ai/streaming_tool_selection_test.go @@ -14,9 +14,9 @@ func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{ + Web: &WebToolsConfig{Search: &SearchConfig{ Exa: ProviderExaConfig{APIKey: "test-key"}, - }, + }}, }, }, }, @@ -46,9 +46,9 @@ func TestSelectedBuiltinToolsForTurn_ModelRoomGetsNoTools(t *testing.T) { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{ + Web: &WebToolsConfig{Search: &SearchConfig{ Exa: ProviderExaConfig{APIKey: "test-key"}, - }, + }}, }, }, }, diff --git a/bridges/ai/subagent_spawn.go b/bridges/ai/subagent_spawn.go index ef66175d..bdadca23 100644 --- a/bridges/ai/subagent_spawn.go +++ b/bridges/ai/subagent_spawn.go @@ -331,7 +331,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P } eventID := agentremote.NewEventID("subagent") - promptContext, err := oc.buildContextWithLinkContext(ctx, childPortal, childMeta, task, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(ctx, childPortal, childMeta, task, nil, eventID) if err != nil { return tools.JSONResult(map[string]any{ "status": "error", diff --git a/bridges/ai/token_resolver.go b/bridges/ai/token_resolver.go index 9bd783f7..abacc936 100644 --- a/bridges/ai/token_resolver.go +++ b/bridges/ai/token_resolver.go @@ -26,6 +26,13 @@ type ServiceConfig struct { type ServiceConfigMap map[string]ServiceConfig +func (oc *OpenAIConnector) modelProviderConfig(provider string) ModelProviderConfig { + if oc == nil || oc.Config.Models == nil { + return ModelProviderConfig{} + } + return oc.Config.Models.Provider(provider) +} + func trimToken(value string) string { return strings.TrimSpace(value) } @@ -118,7 +125,7 @@ func (oc *OpenAIConnector) resolveExaProxyBaseURL(meta *UserLoginMetadata) strin } func (oc *OpenAIConnector) resolveOpenAIBaseURL() string { - base := strings.TrimSpace(oc.Config.Providers.OpenAI.BaseURL) + base := strings.TrimSpace(oc.modelProviderConfig(ProviderOpenAI).BaseURL) if base == "" { base = defaultOpenAIBaseURL } @@ -126,7 +133,7 @@ func (oc *OpenAIConnector) resolveOpenAIBaseURL() string { } func (oc *OpenAIConnector) resolveOpenRouterBaseURL() string { - base := strings.TrimSpace(oc.Config.Providers.OpenRouter.BaseURL) + base := strings.TrimSpace(oc.modelProviderConfig(ProviderOpenRouter).BaseURL) if base == "" { base = defaultOpenRouterBaseURL } @@ -190,7 +197,7 @@ func (oc *OpenAIConnector) resolveProviderAPIKey(meta *UserLoginMetadata) string return trimToken(meta.ServiceTokens.OpenRouter) } case ProviderOpenRouter: - if key := trimToken(oc.Config.Providers.OpenRouter.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenRouter).APIKey); key != "" { return key } if key := trimToken(meta.APIKey); key != "" { @@ -200,7 +207,7 @@ func (oc *OpenAIConnector) resolveProviderAPIKey(meta *UserLoginMetadata) string return trimToken(meta.ServiceTokens.OpenRouter) } case ProviderOpenAI: - if key := trimToken(oc.Config.Providers.OpenAI.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenAI).APIKey); key != "" { return key } if key := trimToken(meta.APIKey); key != "" { @@ -216,7 +223,7 @@ func (oc *OpenAIConnector) resolveProviderAPIKey(meta *UserLoginMetadata) string } func (oc *OpenAIConnector) resolveOpenAIAPIKey(meta *UserLoginMetadata) string { - if key := trimToken(oc.Config.Providers.OpenAI.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenAI).APIKey); key != "" { return key } if meta == nil { @@ -234,7 +241,7 @@ func (oc *OpenAIConnector) resolveOpenAIAPIKey(meta *UserLoginMetadata) string { } func (oc *OpenAIConnector) resolveOpenRouterAPIKey(meta *UserLoginMetadata) string { - if key := trimToken(oc.Config.Providers.OpenRouter.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenRouter).APIKey); key != "" { return key } if meta == nil { diff --git a/bridges/ai/tool_availability_configured_test.go b/bridges/ai/tool_availability_configured_test.go index 3e940aab..43e9dd78 100644 --- a/bridges/ai/tool_availability_configured_test.go +++ b/bridges/ai/tool_availability_configured_test.go @@ -21,7 +21,7 @@ func TestToolAvailable_WebSearch_RequiresAnyProviderKey(t *testing.T) { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{}, + Web: &WebToolsConfig{Search: &SearchConfig{}}, }, }, }, @@ -48,9 +48,9 @@ func TestToolAvailable_WebSearch_WithProviderKey(t *testing.T) { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{ + Web: &WebToolsConfig{Search: &SearchConfig{ Exa: ProviderExaConfig{APIKey: "test"}, - }, + }}, }, }, }, @@ -71,9 +71,9 @@ func TestToolAvailable_WebFetch_DirectDisabledAndNoExaKey(t *testing.T) { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Fetch: &FetchConfig{ + Web: &WebToolsConfig{Fetch: &FetchConfig{ Direct: ProviderDirectConfig{Enabled: boolPtr(false)}, - }, + }}, }, }, }, diff --git a/bridges/ai/tool_configured.go b/bridges/ai/tool_configured.go index bb8e096d..8b61096a 100644 --- a/bridges/ai/tool_configured.go +++ b/bridges/ai/tool_configured.go @@ -17,10 +17,10 @@ func (oc *AIClient) effectiveSearchConfig(_ context.Context) *search.Config { return effectiveToolConfig( oc, func(connector *OpenAIConnector) *search.Config { - if connector == nil { + if connector == nil || connector.Config.Tools.Web == nil { return nil } - return mapSearchConfig(connector.Config.Tools.Search) + return mapSearchConfig(connector.Config.Tools.Web.Search) }, applyLoginTokensToSearchConfig, func(cfg *search.Config) *search.Config { return search.ApplyEnvDefaults(cfg).WithDefaults() }, @@ -31,10 +31,10 @@ func (oc *AIClient) effectiveFetchConfig(_ context.Context) *fetch.Config { return effectiveToolConfig( oc, func(connector *OpenAIConnector) *fetch.Config { - if connector == nil { + if connector == nil || connector.Config.Tools.Web == nil { return nil } - return mapFetchConfig(connector.Config.Tools.Fetch) + return mapFetchConfig(connector.Config.Tools.Web.Fetch) }, applyLoginTokensToFetchConfig, func(cfg *fetch.Config) *fetch.Config { return fetch.ApplyEnvDefaults(cfg).WithDefaults() }, diff --git a/bridges/ai/tools_tts_test.go b/bridges/ai/tools_tts_test.go index d7924773..9def44c5 100644 --- a/bridges/ai/tools_tts_test.go +++ b/bridges/ai/tools_tts_test.go @@ -60,9 +60,9 @@ func TestResolveOpenAITTSBaseURLOpenAIProviderUsesConfiguredBase(t *testing.T) { meta := &UserLoginMetadata{Provider: ProviderOpenAI} oc := &OpenAIConnector{ Config: Config{ - Providers: ProvidersConfig{ - OpenAI: ProviderConfig{BaseURL: "https://openai.example/v1"}, - }, + Models: &ModelsConfig{Providers: map[string]ModelProviderConfig{ + ProviderOpenAI: {BaseURL: "https://openai.example/v1"}, + }}, }, } btc := newTTSTestBridgeContext(meta, oc) From bd9bb1eb977176bfa2aa254b94e82450f1086117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 21:57:43 +0200 Subject: [PATCH 03/23] Introduce PromptContext abstraction and refactor usage Replace direct openai.ChatCompletionMessageParamUnion usage with a new PromptContext/PromptMessage model across the AI bridge. Add prompt_context_ops helpers, conversions between PromptContext and chat messages, and estimatePromptContextTokensForModel. Update compaction, retry, streaming (chat/responses), subagent, pending queue, status text, and memory integration code to operate on PromptContext, including new memory prompt augmentor wiring and steering prompt builders. This centralizes prompt handling, removes many direct OpenAI type dependencies, and adapts runtime compaction/truncation to the new representation. --- bridges/ai/agent_loop_test.go | 17 +++-- bridges/ai/compaction_summarization.go | 16 +++-- bridges/ai/pending_queue.go | 25 ++++++- bridges/ai/prompt_context_ops.go | 53 +++++++++++++++ bridges/ai/response_retry.go | 86 ++++++++++++------------ bridges/ai/runtime_compaction_adapter.go | 4 ++ bridges/ai/status_text.go | 7 +- bridges/ai/streaming_chat_completions.go | 63 +++++++++++++---- bridges/ai/streaming_executor.go | 25 +++---- bridges/ai/streaming_responses_api.go | 37 +++++----- bridges/ai/subagent_announce.go | 7 +- bridges/ai/subagent_spawn.go | 6 +- bridges/ai/system_prompts.go | 81 ++++++++++------------ pkg/integrations/memory/integration.go | 35 +++++----- pkg/integrations/memory/prompt_exec.go | 52 ++++++-------- pkg/integrations/runtime/interfaces.go | 16 ----- 16 files changed, 305 insertions(+), 225 deletions(-) create mode 100644 bridges/ai/prompt_context_ops.go diff --git a/bridges/ai/agent_loop_test.go b/bridges/ai/agent_loop_test.go index 730b8fdd..0dd3afaf 100644 --- a/bridges/ai/agent_loop_test.go +++ b/bridges/ai/agent_loop_test.go @@ -5,14 +5,13 @@ import ( "errors" "testing" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/event" ) type fakeAgentLoopProvider struct { track bool results []fakeAgentLoopResult - followUps map[int][]openai.ChatCompletionMessageParamUnion + followUps map[int][]PromptMessage finalizeCalls int continueCalls int roundsObserved []int @@ -41,14 +40,14 @@ func (f *fakeAgentLoopProvider) FinalizeAgentLoop(context.Context) { f.finalizeCalls++ } -func (f *fakeAgentLoopProvider) GetFollowUpMessages(_ context.Context) []openai.ChatCompletionMessageParamUnion { +func (f *fakeAgentLoopProvider) GetFollowUpMessages(_ context.Context) []PromptMessage { if len(f.roundsObserved) == 0 { return nil } return f.followUps[f.roundsObserved[len(f.roundsObserved)-1]] } -func (f *fakeAgentLoopProvider) ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) { +func (f *fakeAgentLoopProvider) ContinueAgentLoop(messages []PromptMessage) { if len(messages) > 0 { f.continueCalls++ } @@ -132,8 +131,14 @@ func TestExecuteAgentLoopRoundsContinuesForFollowUpMessages(t *testing.T) { {continueLoop: false}, {continueLoop: false}, }, - followUps: map[int][]openai.ChatCompletionMessageParamUnion{ - 0: {openai.UserMessage("follow up")}, + followUps: map[int][]PromptMessage{ + 0: {{ + Role: PromptRoleUser, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: "follow up", + }}, + }}, }, } diff --git a/bridges/ai/compaction_summarization.go b/bridges/ai/compaction_summarization.go index 4d425d22..032ab8b2 100644 --- a/bridges/ai/compaction_summarization.go +++ b/bridges/ai/compaction_summarization.go @@ -614,18 +614,20 @@ func injectSystemPromptAtFirstNonSystem( func (oc *AIClient) applyCompactionModelSummaryAndRefresh( ctx context.Context, meta *PortalMetadata, - originalPrompt []openai.ChatCompletionMessageParamUnion, - compactedPrompt []openai.ChatCompletionMessageParamUnion, + originalPrompt PromptContext, + compactedPrompt PromptContext, decision airuntime.CompactionDecision, contextWindowTokens int, -) []openai.ChatCompletionMessageParamUnion { - out := compactedPrompt +) PromptContext { + originalMessages := PromptContextToChatCompletionMessages(originalPrompt, false) + compactedMessages := PromptContextToChatCompletionMessages(compactedPrompt, false) + out := compactedMessages if oc.pruningSummarizationEnabled() { - dropped := selectDroppedCompactionMessages(originalPrompt, compactedPrompt, decision.DroppedCount) + dropped := selectDroppedCompactionMessages(originalMessages, compactedMessages, decision.DroppedCount) if len(dropped) > 0 { model := resolveCompactionSummaryModel(oc.effectiveModel(meta), oc.pruningSummarizationModel()) allMessages := slices.Clone(dropped) - allMessages = append(allMessages, compactedPrompt...) + allMessages = append(allMessages, compactedMessages...) adaptive := computeCompactionAdaptiveChunkRatio(allMessages, model, contextWindowTokens) maxChunkTokens := int(math.Floor(float64(contextWindowTokens)*adaptive)) - compactionSummarizationOverhead if maxChunkTokens <= 0 { @@ -653,5 +655,5 @@ func (oc *AIClient) applyCompactionModelSummaryAndRefresh( if refresh := strings.TrimSpace(oc.pruningPostCompactionRefreshPrompt()); refresh != "" { out = injectSystemPromptAtFirstNonSystem(out, refresh) } - return out + return ChatMessagesToPromptContext(out) } diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index bea251ac..000ce4c7 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -339,7 +339,28 @@ func buildSteeringUserMessages(prompts []string) []openai.ChatCompletionMessageP return messages } -func (oc *AIClient) getFollowUpMessages(roomID id.RoomID) []openai.ChatCompletionMessageParamUnion { +func buildSteeringPromptMessages(prompts []string) []PromptMessage { + if len(prompts) == 0 { + return nil + } + messages := make([]PromptMessage, 0, len(prompts)) + for _, prompt := range prompts { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + continue + } + messages = append(messages, PromptMessage{ + Role: PromptRoleUser, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: prompt, + }}, + }) + } + return messages +} + +func (oc *AIClient) getFollowUpMessages(roomID id.RoomID) []PromptMessage { if oc == nil || roomID == "" { return nil } @@ -362,7 +383,7 @@ func (oc *AIClient) getFollowUpMessages(roomID id.RoomID) []openai.ChatCompletio if !ok { return nil } - return buildSteeringUserMessages([]string{prompt}) + return buildSteeringPromptMessages([]string{prompt}) } func (oc *AIClient) markQueueDraining(roomID id.RoomID) bool { diff --git a/bridges/ai/prompt_context_ops.go b/bridges/ai/prompt_context_ops.go new file mode 100644 index 00000000..57a1450a --- /dev/null +++ b/bridges/ai/prompt_context_ops.go @@ -0,0 +1,53 @@ +package ai + +import "strings" + +func ClonePromptMessages(messages []PromptMessage) []PromptMessage { + if len(messages) == 0 { + return nil + } + out := make([]PromptMessage, 0, len(messages)) + for _, message := range messages { + cloned := message + if len(message.Blocks) > 0 { + cloned.Blocks = append([]PromptBlock(nil), message.Blocks...) + } + out = append(out, cloned) + } + return out +} + +func ClonePromptContext(ctx PromptContext) PromptContext { + cloned := ctx + cloned.Messages = ClonePromptMessages(ctx.Messages) + if len(ctx.Tools) > 0 { + cloned.Tools = append([]ToolDefinition(nil), ctx.Tools...) + } + return cloned +} + +func AppendPromptMessages(ctx *PromptContext, messages ...PromptMessage) { + if ctx == nil || len(messages) == 0 { + return + } + ctx.Messages = append(ctx.Messages, ClonePromptMessages(messages)...) +} + +func PromptContextMessageCount(ctx PromptContext) int { + count := len(ctx.Messages) + if strings.TrimSpace(ctx.SystemPrompt) != "" { + count++ + } + return count +} + +func NewUserTextPromptMessage(text string) PromptMessage { + return PromptMessage{ + Role: PromptRoleUser, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: text, + }}, + } +} + diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index e7531a62..58cf8ef0 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -7,7 +7,6 @@ import ( "math" "slices" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -20,8 +19,7 @@ const ( maxRetryAttempts = 3 // Maximum retry attempts for context length errors ) -// responseFunc is the signature for response handlers that can be retried on context length errors -type responseFunc func(ctx context.Context, evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion) (bool, *ContextLengthError, error) +type responseFuncCanonical func(ctx context.Context, evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, prompt PromptContext) (bool, *ContextLengthError, error) // responseWithRetry wraps a response function with context length retry logic. // It performs one runtime compaction retry attempt. @@ -30,11 +28,11 @@ func (oc *AIClient) responseWithRetry( evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, - responseFn responseFunc, + prompt PromptContext, + responseFn responseFuncCanonical, logLabel string, ) (bool, error) { - currentPrompt := prompt + currentPrompt := ClonePromptContext(prompt) preflightFlushAttempted := false overflowCompactionAttempts := 0 var lastCLE *ContextLengthError @@ -71,7 +69,7 @@ func (oc *AIClient) responseWithRetry( if meta != nil { modelID = oc.effectiveModel(meta) } - tokensBefore := estimatePromptTokensForModel(currentPrompt, modelID) + tokensBefore := estimatePromptContextTokensForModel(currentPrompt, modelID) if overflowCompactionAttempts < maxRetryAttempts { overflowCompactionAttempts++ @@ -87,19 +85,19 @@ func (oc *AIClient) responseWithRetry( ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensBefore, - MessagesBefore: len(currentPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), TokensBefore: tokensBefore, }) oc.emitCompactionStatus(ctx, portal, &CompactionEvent{ Type: CompactionEventStart, SessionID: sessionID, - MessagesBefore: len(currentPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), }) compacted, decision, compactionSuccess := oc.runtimeCompactOnOverflow(currentPrompt, contextWindow, cle.RequestedTokens, tokensBefore) - if compactionSuccess && len(compacted) > 2 { + if compactionSuccess && PromptContextMessageCount(compacted) > 2 { compacted = oc.applyCompactionModelSummaryAndRefresh(ctx, meta, currentPrompt, compacted, decision, contextWindow) - tokensAfter := estimatePromptTokensForModel(compacted, modelID) + tokensAfter := estimatePromptContextTokensForModel(compacted, modelID) if meta != nil { meta.CompactionCount++ oc.savePortalQuiet(ctx, portal, "compaction count") @@ -111,8 +109,8 @@ func (oc *AIClient) responseWithRetry( oc.emitCompactionStatus(ctx, portal, &CompactionEvent{ Type: CompactionEventEnd, SessionID: sessionID, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(compacted), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(compacted), TokensBefore: tokensBefore, TokensAfter: tokensAfter, Summary: summary, @@ -126,8 +124,8 @@ func (oc *AIClient) responseWithRetry( ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensAfter, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(compacted), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(compacted), TokensBefore: tokensBefore, TokensAfter: tokensAfter, DroppedCount: decision.DroppedCount, @@ -136,8 +134,8 @@ func (oc *AIClient) responseWithRetry( }, integrationruntime.CompactionLifecycleEnd, integrationruntime.CompactionLifecycleRefresh) oc.loggerForContext(ctx).Info(). - Int("messages_before", len(currentPrompt)). - Int("messages_after", len(compacted)). + Int("messages_before", PromptContextMessageCount(currentPrompt)). + Int("messages_after", PromptContextMessageCount(compacted)). Int("tokens_before", tokensBefore). Int("tokens_after", tokensAfter). Int("dropped", decision.DroppedCount). @@ -149,12 +147,12 @@ func (oc *AIClient) responseWithRetry( // Compaction was insufficient. Try an explicit tool-result truncation pass. truncatedPrompt, truncatedCount := oc.truncateOversizedToolResultsForOverflow(currentPrompt, contextWindow) if truncatedCount > 0 { - tokensAfter := estimatePromptTokensForModel(truncatedPrompt, modelID) + tokensAfter := estimatePromptContextTokensForModel(truncatedPrompt, modelID) oc.emitCompactionStatus(ctx, portal, &CompactionEvent{ Type: CompactionEventEnd, SessionID: sessionID, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(truncatedPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(truncatedPrompt), TokensBefore: tokensBefore, TokensAfter: tokensAfter, Summary: fmt.Sprintf("Truncated %d oversized tool result(s).", truncatedCount), @@ -169,8 +167,8 @@ func (oc *AIClient) responseWithRetry( ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensAfter, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(truncatedPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(truncatedPrompt), TokensBefore: tokensBefore, TokensAfter: tokensAfter, Reason: "truncate_oversized_tool_results", @@ -199,7 +197,7 @@ func (oc *AIClient) responseWithRetry( ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensBefore, - MessagesBefore: len(currentPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), TokensBefore: tokensBefore, Reason: "compaction did not reduce context sufficiently", Error: "compaction did not reduce context sufficiently", @@ -235,7 +233,7 @@ func (oc *AIClient) runCompactionPreflightFlushHook( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, attempt int, ) { if oc == nil || meta == nil { @@ -246,7 +244,7 @@ func (oc *AIClient) runCompactionPreflightFlushHook( contextWindow = 128000 } modelID := oc.effectiveModel(meta) - promptTokens := estimatePromptTokensForModel(prompt, modelID) + promptTokens := estimatePromptContextTokensForModel(prompt, modelID) projectedTokens := projectedCompactionFlushTokens(meta, promptTokens) oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ Client: oc, @@ -257,7 +255,7 @@ func (oc *AIClient) runCompactionPreflightFlushHook( ContextWindowTokens: contextWindow, RequestedTokens: projectedTokens, PromptTokens: promptTokens, - MessagesBefore: len(prompt), + MessagesBefore: PromptContextMessageCount(prompt), TokensBefore: promptTokens, }) oc.runCompactionFlushHook(ctx, portal, meta, prompt, &ContextLengthError{ @@ -313,7 +311,7 @@ func (oc *AIClient) runCompactionFlushHook( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, cle *ContextLengthError, attempt int, ) { @@ -342,7 +340,7 @@ func (oc *AIClient) runCompactionFlushHook( Client: oc, Portal: portal, Meta: meta, - Prompt: prompt, + Prompt: PromptContextToChatCompletionMessages(prompt, false), RequestedTokens: cle.RequestedTokens, ModelMaxTokens: cle.ModelMaxTokens, Attempt: attempt, @@ -356,9 +354,8 @@ func (oc *AIClient) runAgentLoopWithRetry( meta *PortalMetadata, promptContext PromptContext, ) { - prompt := oc.promptContextToDispatchMessages(ctx, portal, meta, promptContext) responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, promptContext) - success, err := oc.responseWithRetry(ctx, evt, portal, meta, prompt, responseFn, logLabel) + success, err := oc.responseWithRetry(ctx, evt, portal, meta, promptContext, responseFn, logLabel) if success || err == nil { return } @@ -368,7 +365,7 @@ func (oc *AIClient) runAgentLoopWithRetry( oc.notifyMatrixSendFailure(ctx, portal, evt, err) } -func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext PromptContext) (responseFunc, string) { +func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext PromptContext) (responseFuncCanonical, string) { if HasUnsupportedResponsesPromptContext(promptContext) { return oc.runChatCompletionsAgentLoop, "chat_completions" } @@ -379,7 +376,7 @@ func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext P switch oc.resolveModelAPI(meta) { case ModelAPIChatCompletions: if isDirectOpenAIModel(modelID) { - return func(context.Context, *event.Event, *bridgev2.Portal, *PortalMetadata, []openai.ChatCompletionMessageParamUnion) (bool, *ContextLengthError, error) { + return func(context.Context, *event.Event, *bridgev2.Portal, *PortalMetadata, PromptContext) (bool, *ContextLengthError, error) { return false, nil, fmt.Errorf("invalid model configuration: direct OpenAI model %q cannot use chat_completions", modelID) }, "invalid_model_api" } @@ -415,13 +412,14 @@ func (oc *AIClient) notifyContextLengthExceeded( } func (oc *AIClient) runtimeCompactOnOverflow( - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, contextWindowTokens int, requestedTokens int, currentPromptTokens int, -) ([]openai.ChatCompletionMessageParamUnion, airuntime.CompactionDecision, bool) { +) (PromptContext, airuntime.CompactionDecision, bool) { + serialized := PromptContextToChatCompletionMessages(prompt, false) result := airuntime.CompactPromptOnOverflow(airuntime.OverflowCompactionInput{ - Prompt: prompt, + Prompt: serialized, ContextWindowTokens: contextWindowTokens, RequestedTokens: requestedTokens, CurrentPromptTokens: currentPromptTokens, @@ -434,14 +432,14 @@ func (oc *AIClient) runtimeCompactOnOverflow( MaxHistoryShare: oc.pruningMaxHistoryShare(), ProtectedTail: 3, }) - return result.Prompt, result.Decision, result.Success + return ChatMessagesToPromptContext(result.Prompt), result.Decision, result.Success } func (oc *AIClient) truncateOversizedToolResultsForOverflow( - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, contextWindowTokens int, -) ([]openai.ChatCompletionMessageParamUnion, int) { - if len(prompt) == 0 { +) (PromptContext, int) { + if len(prompt.Messages) == 0 { return prompt, 0 } cfg := oc.pruningConfigOrDefault() @@ -460,13 +458,13 @@ func (oc *AIClient) truncateOversizedToolResultsForOverflow( } } - out := slices.Clone(prompt) + out := ClonePromptContext(prompt) truncated := 0 - for i, msg := range out { - if msg.OfTool == nil { + for i, msg := range out.Messages { + if msg.Role != PromptRoleToolResult { continue } - content := airuntime.ExtractToolContent(msg.OfTool.Content) + content := strings.TrimSpace(msg.Text()) if len(content) <= thresholdChars { continue } @@ -474,7 +472,7 @@ func (oc *AIClient) truncateOversizedToolResultsForOverflow( if trimmed == content { continue } - out[i] = openai.ToolMessage(trimmed, msg.OfTool.ToolCallID) + out.Messages[i].Blocks = []PromptBlock{{Type: PromptBlockText, Text: trimmed}} truncated++ } return out, truncated diff --git a/bridges/ai/runtime_compaction_adapter.go b/bridges/ai/runtime_compaction_adapter.go index 33fc2ccc..d2155b8e 100644 --- a/bridges/ai/runtime_compaction_adapter.go +++ b/bridges/ai/runtime_compaction_adapter.go @@ -163,3 +163,7 @@ func estimatePromptTokensForModel(prompt []openai.ChatCompletionMessageParamUnio } return estimatePromptTokensFallback(prompt) } + +func estimatePromptContextTokensForModel(prompt PromptContext, model string) int { + return estimatePromptTokensForModel(PromptContextToChatCompletionMessages(prompt, false), model) +} diff --git a/bridges/ai/status_text.go b/bridges/ai/status_text.go index 0ab7b12b..4b9a747b 100644 --- a/bridges/ai/status_text.go +++ b/bridges/ai/status_text.go @@ -228,13 +228,8 @@ func (oc *AIClient) estimatePromptTokens(ctx context.Context, portal *bridgev2.P if err != nil { return 0 } - prompt := oc.promptContextToDispatchMessages(ctx, portal, meta, promptContext) modelID := oc.effectiveModel(meta) - count, err := EstimateTokens(prompt, modelID) - if err != nil { - return 0 - } - return count + return estimatePromptContextTokensForModel(promptContext, modelID) } func (oc *AIClient) getSessionEntryMaybe(ctx context.Context, agentID, sessionKey string) *sessionEntry { diff --git a/bridges/ai/streaming_chat_completions.go b/bridges/ai/streaming_chat_completions.go index 33fb8986..42032df2 100644 --- a/bridges/ai/streaming_chat_completions.go +++ b/bridges/ai/streaming_chat_completions.go @@ -49,7 +49,7 @@ func (a *chatCompletionsTurnAdapter) RunAgentTurn( typingSignals := a.typingSignals touchTyping := a.touchTyping isHeartbeat := a.isHeartbeat - currentMessages := a.messages + currentMessages := PromptContextToChatCompletionMessages(a.prompt, oc.isOpenRouterProvider()) params := oc.buildChatCompletionsAgentLoopParams(ctx, meta, currentMessages) @@ -139,33 +139,60 @@ func (a *chatCompletionsTurnAdapter) RunAgentTurn( if shouldContinueChatToolLoop(state.finishReason, len(toolCallParams)) { state.needsTextSeparator = true - assistantMsg := openai.ChatCompletionAssistantMessageParam{ - ToolCalls: toolCallParams, + assistantMsg := PromptMessage{ + Role: PromptRoleAssistant, } if content := strings.TrimSpace(roundContent.String()); content != "" { - assistantMsg.Content.OfString = param.NewOpt(content) + assistantMsg.Blocks = append(assistantMsg.Blocks, PromptBlock{ + Type: PromptBlockText, + Text: content, + }) + } + for _, toolCall := range toolCallParams { + if toolCall.OfFunction == nil { + continue + } + assistantMsg.Blocks = append(assistantMsg.Blocks, PromptBlock{ + Type: PromptBlockToolCall, + ToolCallID: toolCall.OfFunction.ID, + ToolName: toolCall.OfFunction.Function.Name, + ToolCallArguments: toolCall.OfFunction.Function.Arguments, + }) + } + if len(assistantMsg.Blocks) > 0 { + a.prompt.Messages = append(a.prompt.Messages, assistantMsg) } - currentMessages = append(currentMessages, openai.ChatCompletionMessageParamUnion{OfAssistant: &assistantMsg}) for _, output := range state.pendingFunctionOutputs { - currentMessages = append(currentMessages, openai.ToolMessage(output.output, output.callID)) + a.prompt.Messages = append(a.prompt.Messages, PromptMessage{ + Role: PromptRoleToolResult, + ToolCallID: output.callID, + ToolName: output.name, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: output.output, + }}, + }) } - currentMessages = append(currentMessages, buildSteeringUserMessages(steeringPrompts)...) + a.prompt.Messages = append(a.prompt.Messages, buildSteeringPromptMessages(steeringPrompts)...) if round >= maxAgentLoopToolTurns { log.Warn().Int("rounds", round+1).Msg("Max tool call rounds reached; stopping chat completions continuation") - currentMessages = append(currentMessages, openai.AssistantMessage("Continuation stopped after reaching the maximum number of streaming tool rounds.")) + a.prompt.Messages = append(a.prompt.Messages, PromptMessage{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: "Continuation stopped after reaching the maximum number of streaming tool rounds.", + }}, + }) state.clearContinuationState() - a.messages = currentMessages return false, nil, nil } // Chat Completions does not support MCP approvals; clearContinuationState // is safe here — it resets pendingFunctionOutputs (consumed above) and // pendingMcpApprovals (always empty for Chat). state.clearContinuationState() - a.messages = currentMessages return true, nil, nil } - a.messages = currentMessages return false, nil, nil } @@ -195,6 +222,16 @@ func (oc *AIClient) runChatCompletionsAgentLoop( portal *bridgev2.Portal, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, +) (bool, *ContextLengthError, error) { + return oc.runChatCompletionsAgentLoopPrompt(ctx, evt, portal, meta, ChatMessagesToPromptContext(messages)) +} + +func (oc *AIClient) runChatCompletionsAgentLoopPrompt( + ctx context.Context, + evt *event.Event, + portal *bridgev2.Portal, + meta *PortalMetadata, + prompt PromptContext, ) (bool, *ContextLengthError, error) { portalID := "" if portal != nil { @@ -205,9 +242,9 @@ func (oc *AIClient) runChatCompletionsAgentLoop( Str("portal", portalID). Logger() - return oc.runAgentLoop(ctx, log, evt, portal, meta, messages, func(prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion) agentLoopProvider { + return oc.runAgentLoop(ctx, log, evt, portal, meta, prompt, func(prep streamingRunPrep, prompt PromptContext) agentLoopProvider { return &chatCompletionsTurnAdapter{ - agentLoopProviderBase: newAgentLoopProviderBase(oc, log, portal, meta, prep, pruned), + agentLoopProviderBase: newAgentLoopProviderBase(oc, log, portal, meta, prep, prompt), } }) } diff --git a/bridges/ai/streaming_executor.go b/bridges/ai/streaming_executor.go index 6a4811eb..fb9de84a 100644 --- a/bridges/ai/streaming_executor.go +++ b/bridges/ai/streaming_executor.go @@ -14,8 +14,8 @@ import ( type agentLoopProvider interface { TrackRoomRunStreaming() bool RunAgentTurn(ctx context.Context, evt *event.Event, round int) (continueLoop bool, cle *ContextLengthError, err error) - GetFollowUpMessages(ctx context.Context) []openai.ChatCompletionMessageParamUnion - ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) + GetFollowUpMessages(ctx context.Context) []PromptMessage + ContinueAgentLoop(messages []PromptMessage) FinalizeAgentLoop(ctx context.Context) } @@ -28,7 +28,7 @@ type agentLoopProviderBase struct { typingSignals *TypingSignaler touchTyping func() isHeartbeat bool - messages []openai.ChatCompletionMessageParamUnion + prompt PromptContext } func newAgentLoopProviderBase( @@ -37,7 +37,7 @@ func newAgentLoopProviderBase( portal *bridgev2.Portal, meta *PortalMetadata, prep streamingRunPrep, - messages []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) agentLoopProviderBase { return agentLoopProviderBase{ oc: oc, @@ -48,22 +48,22 @@ func newAgentLoopProviderBase( typingSignals: prep.TypingSignals, touchTyping: prep.TouchTyping, isHeartbeat: prep.IsHeartbeat, - messages: messages, + prompt: prompt, } } -func (a *agentLoopProviderBase) GetFollowUpMessages(context.Context) []openai.ChatCompletionMessageParamUnion { +func (a *agentLoopProviderBase) GetFollowUpMessages(context.Context) []PromptMessage { if a == nil || a.oc == nil || a.state == nil { return nil } return a.oc.getFollowUpMessages(a.state.roomID) } -func (a *agentLoopProviderBase) ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) { +func (a *agentLoopProviderBase) ContinueAgentLoop(messages []PromptMessage) { if a == nil || len(messages) == 0 { return } - a.messages = append(a.messages, messages...) + a.prompt.Messages = append(a.prompt.Messages, messages...) } func (oc *AIClient) runAgentLoop( @@ -72,14 +72,15 @@ func (oc *AIClient) runAgentLoop( evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, - newProvider func(prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion) agentLoopProvider, + prompt PromptContext, + newProvider func(prep streamingRunPrep, prompt PromptContext) agentLoopProvider, ) (bool, *ContextLengthError, error) { - prep, pruned, typingCleanup := oc.prepareStreamingRun(ctx, log, evt, portal, meta, messages) + messages := PromptContextToChatCompletionMessages(prompt, oc.isOpenRouterProvider()) + prep, _, typingCleanup := oc.prepareStreamingRun(ctx, log, evt, portal, meta, messages) defer typingCleanup() state := prep.State - provider := newProvider(prep, pruned) + provider := newProvider(prep, prompt) if state.roomID != "" { if provider.TrackRoomRunStreaming() { oc.markRoomRunStreaming(state.roomID, true) diff --git a/bridges/ai/streaming_responses_api.go b/bridges/ai/streaming_responses_api.go index b05585a6..566d4e2c 100644 --- a/bridges/ai/streaming_responses_api.go +++ b/bridges/ai/streaming_responses_api.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/packages/param" "github.com/openai/openai-go/v3/packages/ssestream" "github.com/openai/openai-go/v3/responses" @@ -38,16 +37,15 @@ func (a *responsesTurnAdapter) TrackRoomRunStreaming() bool { func (a *responsesTurnAdapter) startInitialRound(ctx context.Context) (*ssestream.Stream[responses.ResponseStreamEventUnion], error) { if !a.initialized { - promptContext := ChatMessagesToPromptContext(a.messages) - input := PromptContextToResponsesInput(promptContext) - a.params = a.oc.buildResponsesAgentLoopParams(ctx, a.meta, promptContext.SystemPrompt, input, false) + input := PromptContextToResponsesInput(a.prompt) + a.params = a.oc.buildResponsesAgentLoopParams(ctx, a.meta, a.prompt.SystemPrompt, input, false) if len(a.params.Tools) > 0 { zerolog.Ctx(ctx).Debug().Int("count", len(a.params.Tools)).Msg("Added streaming turn tools") } if a.oc.isOpenRouterProvider() { ctx = WithPDFEngine(ctx, a.oc.effectivePDFEngine(a.meta)) } - a.state.baseSystemPrompt = promptContext.SystemPrompt + a.state.baseSystemPrompt = a.prompt.SystemPrompt a.initialized = true } stream := a.oc.api.Responses.NewStreaming(ctx, a.params) @@ -122,7 +120,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( stream, err = a.startInitialRound(ctx) params = a.params if err != nil { - logResponsesFailure(a.log, err, params, a.meta, a.messages, "stream_init") + logResponsesFailure(a.log, err, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), "stream_init") return false, nil, &PreDeltaError{Err: err} } } else { @@ -144,7 +142,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( if errors.Is(err, context.Canceled) { return false, nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, state, a.meta, "cancelled", err) } - logResponsesFailure(a.log, err, params, a.meta, a.messages, "continuation_init") + logResponsesFailure(a.log, err, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), "continuation_init") return false, nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, state, a.meta, "error", err) } } @@ -160,7 +158,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( if round > 0 { stage = "continuation_event_error" } - logResponsesFailure(a.log, evtErr, params, a.meta, a.messages, stage) + logResponsesFailure(a.log, evtErr, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), stage) } return done, cle, evtErr }, @@ -169,7 +167,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( if round > 0 { stage = "continuation_err" } - logResponsesFailure(a.log, stepErr, params, a.meta, a.messages, stage) + logResponsesFailure(a.log, stepErr, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), stage) return a.oc.handleResponsesStreamErr(ctx, a.portal, state, a.meta, stepErr, round == 0) }, ) @@ -187,13 +185,12 @@ func (a *responsesTurnAdapter) FinalizeAgentLoop(ctx context.Context) { a.oc.finalizeResponsesStream(ctx, a.log, a.portal, a.state, a.meta) } -func (a *responsesTurnAdapter) ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) { +func (a *responsesTurnAdapter) ContinueAgentLoop(messages []PromptMessage) { if len(messages) == 0 { return } - a.messages = append(a.messages, messages...) - promptContext := ChatMessagesToPromptContext(messages) - a.state.baseInput = append(a.state.baseInput, PromptContextToResponsesInput(promptContext)...) + a.prompt.Messages = append(a.prompt.Messages, messages...) + a.state.baseInput = append(a.state.baseInput, PromptContextToResponsesInput(PromptContext{Messages: messages})...) a.hasFollowUp = true } @@ -473,6 +470,16 @@ func (oc *AIClient) runResponsesAgentLoop( portal *bridgev2.Portal, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, +) (bool, *ContextLengthError, error) { + return oc.runResponsesAgentLoopPrompt(ctx, evt, portal, meta, ChatMessagesToPromptContext(messages)) +} + +func (oc *AIClient) runResponsesAgentLoopPrompt( + ctx context.Context, + evt *event.Event, + portal *bridgev2.Portal, + meta *PortalMetadata, + prompt PromptContext, ) (bool, *ContextLengthError, error) { portalID := "" if portal != nil { @@ -481,8 +488,8 @@ func (oc *AIClient) runResponsesAgentLoop( log := zerolog.Ctx(ctx).With(). Str("portal_id", portalID). Logger() - return oc.runAgentLoop(ctx, log, evt, portal, meta, messages, func(prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion) agentLoopProvider { - base := newAgentLoopProviderBase(oc, log, portal, meta, prep, pruned) + return oc.runAgentLoop(ctx, log, evt, portal, meta, prompt, func(prep streamingRunPrep, prompt PromptContext) agentLoopProvider { + base := newAgentLoopProviderBase(oc, log, portal, meta, prep, prompt) return &responsesTurnAdapter{ agentLoopProviderBase: base, rsc: &responseStreamContext{ diff --git a/bridges/ai/subagent_announce.go b/bridges/ai/subagent_announce.go index cd73675d..c09f4d91 100644 --- a/bridges/ai/subagent_announce.go +++ b/bridges/ai/subagent_announce.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/id" @@ -142,9 +141,9 @@ func (oc *AIClient) runSubagentCompletion( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) (bool, error) { - responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, ChatMessagesToPromptContext(prompt)) + responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, prompt) return oc.responseWithRetry(ctx, nil, portal, meta, prompt, responseFn, logLabel) } @@ -153,7 +152,7 @@ func (oc *AIClient) runSubagentAndAnnounce( run *subagentRun, childPortal *bridgev2.Portal, childMeta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) { if oc == nil || run == nil || childPortal == nil || childMeta == nil { return diff --git a/bridges/ai/subagent_spawn.go b/bridges/ai/subagent_spawn.go index bdadca23..241efbcf 100644 --- a/bridges/ai/subagent_spawn.go +++ b/bridges/ai/subagent_spawn.go @@ -338,8 +338,6 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P "error": err.Error(), }), nil } - promptMessages := oc.promptContextToDispatchMessages(ctx, childPortal, childMeta, promptContext) - userMessage := &database.Message{ ID: agentremote.MatrixMessageID(eventID), MXID: eventID, @@ -370,7 +368,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P Timeout: runTimeout, } oc.registerSubagentRun(run) - oc.startSubagentRun(ctx, run, childPortal, childMeta, promptMessages) + oc.startSubagentRun(ctx, run, childPortal, childMeta, promptContext) payload := map[string]any{ "status": "accepted", @@ -392,7 +390,7 @@ func (oc *AIClient) startSubagentRun( run *subagentRun, childPortal *bridgev2.Portal, childMeta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) { if run == nil || childPortal == nil || childMeta == nil { return diff --git a/bridges/ai/system_prompts.go b/bridges/ai/system_prompts.go index 5162772c..04d1a41c 100644 --- a/bridges/ai/system_prompts.go +++ b/bridges/ai/system_prompts.go @@ -4,7 +4,6 @@ import ( "context" "strings" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" runtimeparse "github.com/beeper/agentremote/pkg/runtime" @@ -55,38 +54,19 @@ func buildSessionIdentityHint(portal *bridgev2.Portal, _ *PortalMetadata) string return "sessionKey: " + session } -func (oc *AIClient) buildAdditionalSystemPrompts( +type memoryPromptAugmentor interface { + PromptContextText(ctx context.Context, portal any, meta any) string +} + +func (oc *AIClient) buildAdditionalSystemPromptText( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - return oc.additionalSystemMessages(ctx, portal, meta) -} - -func systemMessageText(messages []openai.ChatCompletionMessageParamUnion) string { - var parts []string - for _, msg := range messages { - if msg.OfSystem == nil { - continue - } - if text := strings.TrimSpace(msg.OfSystem.Content.OfString.Value); text != "" { - parts = append(parts, text) - continue - } - if len(msg.OfSystem.Content.OfArrayOfContentParts) == 0 { - continue - } - var lines []string - for _, part := range msg.OfSystem.Content.OfArrayOfContentParts { - if text := strings.TrimSpace(part.Text); text != "" { - lines = append(lines, text) - } - } - if len(lines) > 0 { - parts = append(parts, strings.Join(lines, "\n")) - } - } - return strings.TrimSpace(strings.Join(parts, "\n\n")) +) string { + return joinPromptFragments( + oc.buildAdditionalSystemPromptCoreText(ctx, portal, meta), + oc.buildMemoryPromptContextText(ctx, portal, meta), + ) } func (oc *AIClient) buildSystemPromptText( @@ -98,14 +78,7 @@ func (oc *AIClient) buildSystemPromptText( if base == "" { base = oc.effectivePrompt(meta) } - fragments := []string{base, systemMessageText(oc.buildAdditionalSystemPrompts(ctx, portal, meta))} - var parts []string - for _, fragment := range fragments { - if text := strings.TrimSpace(fragment); text != "" { - parts = append(parts, text) - } - } - return strings.TrimSpace(strings.Join(parts, "\n\n")) + return joinPromptFragments(base, oc.buildAdditionalSystemPromptText(ctx, portal, meta)) } func (oc *AIClient) buildConversationSystemPromptText( @@ -121,34 +94,50 @@ func (oc *AIClient) buildConversationSystemPromptText( return joinPromptFragments(sessionGreetingFragment(ctx, portal, meta, oc.log), base) } -func (oc *AIClient) buildAdditionalSystemPromptsCore( +func (oc *AIClient) buildAdditionalSystemPromptCoreText( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - var out []openai.ChatCompletionMessageParamUnion +) string { + var out []string if meta != nil && portal != nil && oc.isGroupChat(ctx, portal) { activation := oc.resolveGroupActivation(meta) intro := buildGroupIntro(oc.matrixRoomDisplayName(ctx, portal), activation) if strings.TrimSpace(intro) != "" { - out = append(out, openai.SystemMessage(intro)) + out = append(out, intro) } } if meta != nil { if verboseHint := buildVerboseSystemHint(meta); verboseHint != "" { - out = append(out, openai.SystemMessage(verboseHint)) + out = append(out, verboseHint) } } if accountHint := oc.buildDesktopAccountHintPrompt(ctx); accountHint != "" { - out = append(out, openai.SystemMessage(accountHint)) + out = append(out, accountHint) } if ident := buildSessionIdentityHint(portal, meta); ident != "" { - out = append(out, openai.SystemMessage(ident)) + out = append(out, ident) } - return out + return joinPromptFragments(out...) +} + +func (oc *AIClient) buildMemoryPromptContextText( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, +) string { + if oc == nil || len(oc.integrationModules) == 0 { + return "" + } + module := oc.integrationModules["memory"] + augmentor, ok := module.(memoryPromptAugmentor) + if !ok || augmentor == nil { + return "" + } + return strings.TrimSpace(augmentor.PromptContextText(ctx, portal, meta)) } diff --git a/pkg/integrations/memory/integration.go b/pkg/integrations/memory/integration.go index 7ba5fb35..b7357bb0 100644 --- a/pkg/integrations/memory/integration.go +++ b/pkg/integrations/memory/integration.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "go.mau.fi/util/dbutil" "github.com/beeper/agentremote/pkg/agents" @@ -26,8 +25,8 @@ type ProviderStatus = memorycore.ProviderStatus type ResolvedConfig = memorycore.ResolvedConfig // Integration is the self-owned memory integration module. -// It implements ToolIntegration, PromptIntegration, CommandIntegration, -// EventIntegration, LoginPurgeIntegration, and LoginLifecycleIntegration +// It implements ToolIntegration, CommandIntegration, EventIntegration, +// LoginPurgeIntegration, and LoginLifecycleIntegration // directly, wiring all deps from Host // capability interfaces. type Integration struct { @@ -78,12 +77,8 @@ func (i *Integration) ToolAvailability(_ context.Context, scope iruntime.ToolSco return true, true, iruntime.SourceGlobalDefault, "" } -func (i *Integration) AdditionalSystemMessages(_ context.Context, _ iruntime.PromptScope) []openai.ChatCompletionMessageParamUnion { - return nil -} - -func (i *Integration) AugmentPrompt(ctx context.Context, scope iruntime.PromptScope, prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { - return AugmentPrompt(ctx, scope, prompt, PromptAugmentDeps{ +func (i *Integration) PromptContextText(ctx context.Context, portal any, meta any) string { + return BuildPromptContextText(ctx, portal, meta, PromptContextDeps{ ShouldInjectContext: i.shouldInjectMemoryPromptContext, ShouldBootstrap: i.shouldBootstrapMemoryPromptContext, ResolveBootstrapPaths: i.resolveMemoryBootstrapPaths, @@ -279,7 +274,7 @@ func (i *Integration) buildOverflowDeps() OverflowDeps { } } -func (i *Integration) shouldInjectMemoryPromptContext(scope iruntime.PromptScope) bool { +func (i *Integration) shouldInjectMemoryPromptContext(_ any, _ any) bool { if cfg := i.host.ModuleConfig(moduleName); cfg != nil { inject, _ := cfg["inject_context"].(bool) return inject @@ -287,15 +282,15 @@ func (i *Integration) shouldInjectMemoryPromptContext(scope iruntime.PromptScope return false } -func (i *Integration) shouldBootstrapMemoryPromptContext(scope iruntime.PromptScope) bool { - raw := i.host.GetModuleMeta(scope.Meta, "memory_bootstrap_at") +func (i *Integration) shouldBootstrapMemoryPromptContext(_ any, meta any) bool { + raw := i.host.GetModuleMeta(meta, "memory_bootstrap_at") if raw == nil { return true } return toInt64(raw) == 0 } -func (i *Integration) resolveMemoryBootstrapPaths(_ iruntime.PromptScope) []string { +func (i *Integration) resolveMemoryBootstrapPaths(_ any, _ any) []string { _, loc := i.host.UserTimezone() if loc == nil { loc = time.UTC @@ -309,18 +304,18 @@ func (i *Integration) resolveMemoryBootstrapPaths(_ iruntime.PromptScope) []stri } } -func (i *Integration) markMemoryPromptBootstrapped(ctx context.Context, scope iruntime.PromptScope) { - if scope.Portal == nil || scope.Meta == nil { +func (i *Integration) markMemoryPromptBootstrapped(ctx context.Context, portal any, meta any) { + if portal == nil || meta == nil { return } - i.host.SetModuleMeta(scope.Meta, "memory_bootstrap_at", time.Now().UnixMilli()) - _ = i.host.SavePortal(ctx, scope.Portal, "memory bootstrap") + i.host.SetModuleMeta(meta, "memory_bootstrap_at", time.Now().UnixMilli()) + _ = i.host.SavePortal(ctx, portal, "memory bootstrap") } -func (i *Integration) readMemoryPromptSection(ctx context.Context, scope iruntime.PromptScope, path string) string { +func (i *Integration) readMemoryPromptSection(ctx context.Context, meta any, path string) string { agentID := "" - if scope.Meta != nil { - agentID = i.host.AgentIDFromMeta(scope.Meta) + if meta != nil { + agentID = i.host.AgentIDFromMeta(meta) } content, filePath, found, err := i.host.ReadTextFile(ctx, agentID, path) if err != nil || !found { diff --git a/pkg/integrations/memory/prompt_exec.go b/pkg/integrations/memory/prompt_exec.go index 5a480df9..8ad0bc63 100644 --- a/pkg/integrations/memory/prompt_exec.go +++ b/pkg/integrations/memory/prompt_exec.go @@ -2,60 +2,52 @@ package memory import ( "context" - "slices" "strings" - - "github.com/openai/openai-go/v3" - - iruntime "github.com/beeper/agentremote/pkg/integrations/runtime" ) -type PromptAugmentDeps struct { - ShouldInjectContext func(scope iruntime.PromptScope) bool - ShouldBootstrap func(scope iruntime.PromptScope) bool - ResolveBootstrapPaths func(scope iruntime.PromptScope) []string - MarkBootstrapped func(ctx context.Context, scope iruntime.PromptScope) - ReadSection func(ctx context.Context, scope iruntime.PromptScope, path string) string +type PromptContextDeps struct { + ShouldInjectContext func(portal any, meta any) bool + ShouldBootstrap func(portal any, meta any) bool + ResolveBootstrapPaths func(portal any, meta any) []string + MarkBootstrapped func(ctx context.Context, portal any, meta any) + ReadSection func(ctx context.Context, meta any, path string) string } -func AugmentPrompt( +func BuildPromptContextText( ctx context.Context, - scope iruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, - deps PromptAugmentDeps, -) []openai.ChatCompletionMessageParamUnion { - if deps.ShouldInjectContext == nil || !deps.ShouldInjectContext(scope) { - return prompt + portal any, + meta any, + deps PromptContextDeps, +) string { + if deps.ShouldInjectContext == nil || !deps.ShouldInjectContext(portal, meta) { + return "" } if deps.ReadSection == nil { - return prompt + return "" } sections := make([]string, 0, 3) - if section := deps.ReadSection(ctx, scope, "MEMORY.md"); section != "" { + if section := deps.ReadSection(ctx, meta, "MEMORY.md"); section != "" { sections = append(sections, section) - } else if section := deps.ReadSection(ctx, scope, "memory.md"); section != "" { + } else if section := deps.ReadSection(ctx, meta, "memory.md"); section != "" { sections = append(sections, section) } - if deps.ShouldBootstrap != nil && deps.ShouldBootstrap(scope) { + if deps.ShouldBootstrap != nil && deps.ShouldBootstrap(portal, meta) { if deps.ResolveBootstrapPaths != nil { - for _, path := range deps.ResolveBootstrapPaths(scope) { - if section := deps.ReadSection(ctx, scope, path); section != "" { + for _, path := range deps.ResolveBootstrapPaths(portal, meta) { + if section := deps.ReadSection(ctx, meta, path); section != "" { sections = append(sections, section) } } } if deps.MarkBootstrapped != nil { - deps.MarkBootstrapped(ctx, scope) + deps.MarkBootstrapped(ctx, portal, meta) } } if len(sections) == 0 { - return prompt + return "" } - contextText := strings.Join(sections, "\n\n") - out := slices.Clone(prompt) - out = append(out, openai.SystemMessage(contextText)) - return out + return strings.Join(sections, "\n\n") } diff --git a/pkg/integrations/runtime/interfaces.go b/pkg/integrations/runtime/interfaces.go index 4c5672b6..61d0c527 100644 --- a/pkg/integrations/runtime/interfaces.go +++ b/pkg/integrations/runtime/interfaces.go @@ -2,8 +2,6 @@ package runtime import ( "context" - - "github.com/openai/openai-go/v3" ) // SettingSource indicates where a setting value came from. @@ -42,13 +40,6 @@ type ToolCall struct { Scope ToolScope } -// PromptScope carries prompt-building context without coupling to connector internals. -type PromptScope struct { - Client any - Portal any - Meta any -} - // ToolIntegration is the pluggable surface for tool definitions/availability/execution. type ToolIntegration interface { Name() string @@ -57,13 +48,6 @@ type ToolIntegration interface { ToolAvailability(ctx context.Context, scope ToolScope, toolName string) (known bool, available bool, source SettingSource, reason string) } -// PromptIntegration is the pluggable surface for prompt/system message augmentation. -type PromptIntegration interface { - Name() string - AdditionalSystemMessages(ctx context.Context, scope PromptScope) []openai.ChatCompletionMessageParamUnion - AugmentPrompt(ctx context.Context, scope PromptScope, prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion -} - // ToolApprovalIntegration is an optional seam for tool approval policy overrides. type ToolApprovalIntegration interface { Name() string From ca8d7935d500d4a7c1668949a58742cdb5e379af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:03:04 +0200 Subject: [PATCH 04/23] Migrate prompts to PromptMessage API Replace usages of openai ChatCompletion message types with the internal PromptMessage API (PromptRoleUser/Text()). Update tests to assert Role/Text() and rename buildSteeringUserMessages -> buildSteeringPromptMessages. Remove the promptIntegrationRegistry, corePromptIntegration, and related prompt augmentation/additional message helpers and promptRegistry field; core integrations now only register tool integrations. Switch agent-loop selection and callers to use the *Prompt variants (runChatCompletionsAgentLoopPrompt / runResponsesAgentLoopPrompt) and remove conversions that produced openai.ChatCompletionMessageParamUnion. Adjust imports accordingly. --- bridges/ai/agent_loop_steering_test.go | 44 ++++----- bridges/ai/client.go | 15 --- bridges/ai/integrations.go | 112 +---------------------- bridges/ai/integrations_test.go | 60 ------------ bridges/ai/pending_queue.go | 16 ---- bridges/ai/prompt_context_ops.go | 1 - bridges/ai/response_retry.go | 6 +- bridges/ai/streaming_chat_completions.go | 10 -- bridges/ai/streaming_responses_api.go | 10 -- bridges/ai/subagent_spawn.go | 1 - pkg/integrations/memory/integration.go | 1 + 11 files changed, 27 insertions(+), 249 deletions(-) diff --git a/bridges/ai/agent_loop_steering_test.go b/bridges/ai/agent_loop_steering_test.go index 1a5330b5..703f6a87 100644 --- a/bridges/ai/agent_loop_steering_test.go +++ b/bridges/ai/agent_loop_steering_test.go @@ -53,15 +53,15 @@ func TestGetSteeringMessages_FiltersAndDrainsQueue(t *testing.T) { } func TestBuildSteeringUserMessages(t *testing.T) { - got := buildSteeringUserMessages([]string{"first", " ", "second"}) + got := buildSteeringPromptMessages([]string{"first", " ", "second"}) if len(got) != 2 { - t.Fatalf("expected 2 steering user messages, got %d", len(got)) + t.Fatalf("expected 2 steering prompt messages, got %d", len(got)) } - if got[0].OfUser == nil || got[0].OfUser.Content.OfString.Value != "first" { - t.Fatalf("unexpected first steering user message: %#v", got[0]) + if got[0].Role != PromptRoleUser || got[0].Text() != "first" { + t.Fatalf("unexpected first steering prompt message: %#v", got[0]) } - if got[1].OfUser == nil || got[1].OfUser.Content.OfString.Value != "second" { - t.Fatalf("unexpected second steering user message: %#v", got[1]) + if got[1].Role != PromptRoleUser || got[1].Text() != "second" { + t.Fatalf("unexpected second steering prompt message: %#v", got[1]) } } @@ -79,7 +79,7 @@ func TestGetFollowUpMessages_ConsumesSingleQueuedTextMessage(t *testing.T) { } messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil || messages[0].OfUser.Content.OfString.Value != "follow up" { + if len(messages) != 1 || messages[0].Role != PromptRoleUser || messages[0].Text() != "follow up" { t.Fatalf("unexpected follow-up messages: %#v", messages) } if snapshot := oc.getQueueSnapshot(roomID); snapshot != nil { @@ -102,11 +102,11 @@ func TestGetFollowUpMessages_CollectsQueuedTextMessages(t *testing.T) { } messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil { + if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one combined follow-up message, got %#v", messages) } - if messages[0].OfUser.Content.OfString.Value != "[Queued messages while agent was busy]\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { - t.Fatalf("unexpected combined follow-up prompt: %q", messages[0].OfUser.Content.OfString.Value) + if messages[0].Text() != "[Queued messages while agent was busy]\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { + t.Fatalf("unexpected combined follow-up prompt: %q", messages[0].Text()) } } @@ -128,11 +128,11 @@ func TestGetFollowUpMessages_CollectSummaryIsConsumed(t *testing.T) { } messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil { + if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one combined follow-up message, got %#v", messages) } - if messages[0].OfUser.Content.OfString.Value != "[Queued messages while agent was busy]\n\n[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { - t.Fatalf("unexpected combined follow-up prompt with summary: %q", messages[0].OfUser.Content.OfString.Value) + if messages[0].Text() != "[Queued messages while agent was busy]\n\n[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { + t.Fatalf("unexpected combined follow-up prompt with summary: %q", messages[0].Text()) } if again := oc.getFollowUpMessages(roomID); len(again) != 0 { @@ -160,11 +160,11 @@ func TestGetFollowUpMessages_UsesSyntheticSummaryPrompt(t *testing.T) { } messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil { + if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one synthetic follow-up message, got %#v", messages) } - if messages[0].OfUser.Content.OfString.Value != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { - t.Fatalf("unexpected synthetic follow-up prompt: %q", messages[0].OfUser.Content.OfString.Value) + if messages[0].Text() != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { + t.Fatalf("unexpected synthetic follow-up prompt: %q", messages[0].Text()) } } @@ -185,19 +185,19 @@ func TestGetFollowUpMessages_SyntheticSummaryIsConsumedBeforeLatestMessage(t *te } first := oc.getFollowUpMessages(roomID) - if len(first) != 1 || first[0].OfUser == nil { + if len(first) != 1 || first[0].Role != PromptRoleUser { t.Fatalf("expected one synthetic follow-up message, got %#v", first) } - if first[0].OfUser.Content.OfString.Value != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { - t.Fatalf("unexpected first synthetic follow-up prompt: %q", first[0].OfUser.Content.OfString.Value) + if first[0].Text() != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { + t.Fatalf("unexpected first synthetic follow-up prompt: %q", first[0].Text()) } second := oc.getFollowUpMessages(roomID) - if len(second) != 1 || second[0].OfUser == nil { + if len(second) != 1 || second[0].Role != PromptRoleUser { t.Fatalf("expected queued latest message after summary, got %#v", second) } - if second[0].OfUser.Content.OfString.Value != "latest" { - t.Fatalf("expected latest queued message after consuming summary, got %q", second[0].OfUser.Content.OfString.Value) + if second[0].Text() != "latest" { + t.Fatalf("expected latest queued message after consuming summary, got %q", second[0].Text()) } if third := oc.getFollowUpMessages(roomID); len(third) != 0 { diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 54386239..70008d7b 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -315,7 +315,6 @@ type AIClient struct { integrationOrder []string toolRegistry *toolIntegrationRegistry - promptRegistry *promptIntegrationRegistry commandRegistry *commandIntegrationRegistry eventRegistry *eventIntegrationRegistry purgeRegistry *purgeIntegrationRegistry @@ -1683,20 +1682,6 @@ func (oc *AIClient) updateAssistantGeneratedFiles(ctx context.Context, portal *b oc.Log().Warn().Msg("No assistant message found to update with async GeneratedFiles") } -func (oc *AIClient) promptContextToDispatchMessages( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - promptContext PromptContext, -) []openai.ChatCompletionMessageParamUnion { - promptMessages := PromptContextToChatCompletionMessages(promptContext, oc.isOpenRouterProvider()) - promptMessages = oc.augmentPromptWithIntegrations(ctx, portal, meta, promptMessages) - if meta != nil && IsGoogleModel(oc.effectiveModel(meta)) { - promptMessages = SanitizeGoogleTurnOrdering(promptMessages) - } - return promptMessages -} - type historyLoadResult struct { rows []*database.Message hasVision bool diff --git a/bridges/ai/integrations.go b/bridges/ai/integrations.go index fadb3283..9f8e9177 100644 --- a/bridges/ai/integrations.go +++ b/bridges/ai/integrations.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "github.com/beeper/agentremote" @@ -80,46 +79,6 @@ func (r *toolIntegrationRegistry) availability( return false, false, SourceGlobalDefault, "" } -type promptIntegrationRegistry struct { - items []integrationruntime.PromptIntegration -} - -func (r *promptIntegrationRegistry) register(integration integrationruntime.PromptIntegration) { - if integration == nil { - return - } - r.items = append(r.items, integration) -} - -func (r *promptIntegrationRegistry) additionalMessages( - ctx context.Context, - scope integrationruntime.PromptScope, -) []openai.ChatCompletionMessageParamUnion { - if r == nil { - return nil - } - var out []openai.ChatCompletionMessageParamUnion - for _, integration := range r.items { - out = append(out, integration.AdditionalSystemMessages(ctx, scope)...) - } - return out -} - -func (r *promptIntegrationRegistry) augmentPrompt( - ctx context.Context, - scope integrationruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - if r == nil { - return prompt - } - out := prompt - for _, integration := range r.items { - out = integration.AugmentPrompt(ctx, scope, out) - } - return out -} - type commandIntegrationRegistration struct { integration integrationruntime.CommandIntegration definition integrationruntime.CommandDefinition @@ -266,14 +225,6 @@ func (oc *AIClient) toolScope(portal *bridgev2.Portal, meta *PortalMetadata) int } } -func (oc *AIClient) promptScope(portal *bridgev2.Portal, meta *PortalMetadata) integrationruntime.PromptScope { - return integrationruntime.PromptScope{ - Client: oc, - Portal: portal, - Meta: meta, - } -} - func (oc *AIClient) commandScope(portal *bridgev2.Portal, meta *PortalMetadata, evt any) integrationruntime.CommandScope { return integrationruntime.CommandScope{ Client: oc, @@ -288,7 +239,6 @@ func (oc *AIClient) initIntegrations() { return } oc.toolRegistry = &toolIntegrationRegistry{} - oc.promptRegistry = &promptIntegrationRegistry{} oc.commandRegistry = newCommandIntegrationRegistry() oc.eventRegistry = &eventIntegrationRegistry{} oc.purgeRegistry = &purgeIntegrationRegistry{} @@ -307,9 +257,6 @@ func (oc *AIClient) initIntegrations() { if toolIntegration, ok := module.(integrationruntime.ToolIntegration); ok { oc.toolRegistry.register(toolIntegration) } - if promptIntegration, ok := module.(integrationruntime.PromptIntegration); ok { - oc.promptRegistry.register(promptIntegration) - } if commandIntegration, ok := module.(integrationruntime.CommandIntegration); ok { defs := commandIntegration.CommandDefinitions(context.Background(), oc.commandScope(nil, nil, nil)) oc.commandRegistry.register(commandIntegration, defs) @@ -325,11 +272,9 @@ func (oc *AIClient) initIntegrations() { } } - // Register core integrations after modules so module tool/prompt implementations take precedence. + // Register core integrations after modules so module tool implementations take precedence. coreTools := &coreToolIntegration{client: oc} - corePrompts := &corePromptIntegration{client: oc} oc.toolRegistry.register(coreTools) - oc.promptRegistry.register(corePrompts) registerModuleCommands(oc.commandRegistry.definitions()) } @@ -523,35 +468,6 @@ func (oc *AIClient) executeIntegratedTool( }) } -func (oc *AIClient) additionalSystemMessages( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - if oc == nil { - return nil - } - if oc.promptRegistry == nil { - return oc.buildAdditionalSystemPromptsCore(ctx, portal, meta) - } - return oc.promptRegistry.additionalMessages(ctx, oc.promptScope(portal, meta)) -} - -func (oc *AIClient) augmentPromptWithIntegrations( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - if oc == nil { - return prompt - } - if oc.promptRegistry == nil { - return prompt - } - return oc.promptRegistry.augmentPrompt(ctx, oc.promptScope(portal, meta), prompt) -} - func (oc *AIClient) executeIntegratedCommand( ctx context.Context, portal *bridgev2.Portal, @@ -711,29 +627,3 @@ func (c *coreToolIntegration) ToolAvailability( ) (bool, bool, integrationruntime.SettingSource, string) { return false, false, integrationruntime.SourceGlobalDefault, "" } - -type corePromptIntegration struct { - client *AIClient -} - -func (c *corePromptIntegration) Name() string { return "core" } - -func (c *corePromptIntegration) AdditionalSystemMessages( - ctx context.Context, - scope integrationruntime.PromptScope, -) []openai.ChatCompletionMessageParamUnion { - if c == nil || c.client == nil { - return nil - } - portal, _ := scope.Portal.(*bridgev2.Portal) - meta, _ := scope.Meta.(*PortalMetadata) - return c.client.buildAdditionalSystemPromptsCore(ctx, portal, meta) -} - -func (c *corePromptIntegration) AugmentPrompt( - _ context.Context, - _ integrationruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - return prompt -} diff --git a/bridges/ai/integrations_test.go b/bridges/ai/integrations_test.go index 517def4f..b970a6f5 100644 --- a/bridges/ai/integrations_test.go +++ b/bridges/ai/integrations_test.go @@ -3,11 +3,8 @@ package ai import ( "context" "reflect" - "slices" "testing" - "github.com/openai/openai-go/v3" - integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" ) @@ -30,28 +27,6 @@ func (f fakeToolIntegration) ToolAvailability(_ context.Context, _ integrationru return false, false, integrationruntime.SourceGlobalDefault, "" } -type fakePromptIntegration struct { - name string - tag string -} - -func (f fakePromptIntegration) Name() string { return f.name } - -func (f fakePromptIntegration) AdditionalSystemMessages(_ context.Context, _ integrationruntime.PromptScope) []openai.ChatCompletionMessageParamUnion { - return []openai.ChatCompletionMessageParamUnion{openai.SystemMessage("sys:" + f.tag)} -} - -func (f fakePromptIntegration) AugmentPrompt( - _ context.Context, - _ integrationruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - out := make([]openai.ChatCompletionMessageParamUnion, 0, len(prompt)+1) - out = append(out, prompt...) - out = append(out, openai.UserMessage("aug:"+f.tag)) - return out -} - func TestToolIntegrationRegistryDefinitionsDeterministic(t *testing.T) { reg := &toolIntegrationRegistry{} reg.register(fakeToolIntegration{name: "one", defs: []integrationruntime.ToolDefinition{{Name: "a"}, {Name: "b"}}}) @@ -68,41 +43,6 @@ func TestToolIntegrationRegistryDefinitionsDeterministic(t *testing.T) { } } -func TestPromptIntegrationRegistryOrder(t *testing.T) { - reg := &promptIntegrationRegistry{} - reg.register(fakePromptIntegration{name: "one", tag: "1"}) - reg.register(fakePromptIntegration{name: "two", tag: "2"}) - - sys := reg.additionalMessages(context.Background(), integrationruntime.PromptScope{}) - if len(sys) != 2 { - t.Fatalf("expected 2 system messages, got %d", len(sys)) - } - - base := []openai.ChatCompletionMessageParamUnion{openai.UserMessage("base")} - out := reg.augmentPrompt(context.Background(), integrationruntime.PromptScope{}, base) - if len(out) != 3 { - t.Fatalf("expected augmented prompt len=3, got %d", len(out)) - } -} - -func TestPromptIntegrationRegistryAugmentPromptIdempotent(t *testing.T) { - reg := &promptIntegrationRegistry{} - reg.register(fakePromptIntegration{name: "one", tag: "1"}) - reg.register(fakePromptIntegration{name: "two", tag: "2"}) - - base := []openai.ChatCompletionMessageParamUnion{openai.UserMessage("base")} - baseCopy := slices.Clone(base) - - outA := reg.augmentPrompt(context.Background(), integrationruntime.PromptScope{}, base) - outB := reg.augmentPrompt(context.Background(), integrationruntime.PromptScope{}, base) - if !reflect.DeepEqual(outA, outB) { - t.Fatalf("augmentPrompt should be deterministic/idempotent; got outA=%v outB=%v", outA, outB) - } - if !reflect.DeepEqual(base, baseCopy) { - t.Fatalf("augmentPrompt mutated input prompt; got=%v want=%v", base, baseCopy) - } -} - type fakeLifecycleIntegration struct { startCount int stopCount int diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 000ce4c7..cee46466 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" @@ -324,21 +323,6 @@ func (oc *AIClient) getSteeringMessages(roomID id.RoomID) []string { return messages } -func buildSteeringUserMessages(prompts []string) []openai.ChatCompletionMessageParamUnion { - if len(prompts) == 0 { - return nil - } - messages := make([]openai.ChatCompletionMessageParamUnion, 0, len(prompts)) - for _, prompt := range prompts { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - continue - } - messages = append(messages, openai.UserMessage(prompt)) - } - return messages -} - func buildSteeringPromptMessages(prompts []string) []PromptMessage { if len(prompts) == 0 { return nil diff --git a/bridges/ai/prompt_context_ops.go b/bridges/ai/prompt_context_ops.go index 57a1450a..dca40eb1 100644 --- a/bridges/ai/prompt_context_ops.go +++ b/bridges/ai/prompt_context_ops.go @@ -50,4 +50,3 @@ func NewUserTextPromptMessage(text string) PromptMessage { }}, } } - diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index 58cf8ef0..9a0e05dd 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -367,7 +367,7 @@ func (oc *AIClient) runAgentLoopWithRetry( func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext PromptContext) (responseFuncCanonical, string) { if HasUnsupportedResponsesPromptContext(promptContext) { - return oc.runChatCompletionsAgentLoop, "chat_completions" + return oc.runChatCompletionsAgentLoopPrompt, "chat_completions" } modelID := "" if oc != nil { @@ -380,9 +380,9 @@ func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext P return false, nil, fmt.Errorf("invalid model configuration: direct OpenAI model %q cannot use chat_completions", modelID) }, "invalid_model_api" } - return oc.runChatCompletionsAgentLoop, "chat_completions" + return oc.runChatCompletionsAgentLoopPrompt, "chat_completions" default: - return oc.runResponsesAgentLoop, "responses" + return oc.runResponsesAgentLoopPrompt, "responses" } } diff --git a/bridges/ai/streaming_chat_completions.go b/bridges/ai/streaming_chat_completions.go index 42032df2..dbf81747 100644 --- a/bridges/ai/streaming_chat_completions.go +++ b/bridges/ai/streaming_chat_completions.go @@ -216,16 +216,6 @@ func (a *chatCompletionsTurnAdapter) FinalizeAgentLoop(ctx context.Context) { } -func (oc *AIClient) runChatCompletionsAgentLoop( - ctx context.Context, - evt *event.Event, - portal *bridgev2.Portal, - meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, -) (bool, *ContextLengthError, error) { - return oc.runChatCompletionsAgentLoopPrompt(ctx, evt, portal, meta, ChatMessagesToPromptContext(messages)) -} - func (oc *AIClient) runChatCompletionsAgentLoopPrompt( ctx context.Context, evt *event.Event, diff --git a/bridges/ai/streaming_responses_api.go b/bridges/ai/streaming_responses_api.go index 566d4e2c..0dd014ff 100644 --- a/bridges/ai/streaming_responses_api.go +++ b/bridges/ai/streaming_responses_api.go @@ -464,16 +464,6 @@ func (oc *AIClient) handleProviderToolCompleted( } // runResponsesAgentLoop handles the Responses API provider adapter under the canonical agent loop. -func (oc *AIClient) runResponsesAgentLoop( - ctx context.Context, - evt *event.Event, - portal *bridgev2.Portal, - meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, -) (bool, *ContextLengthError, error) { - return oc.runResponsesAgentLoopPrompt(ctx, evt, portal, meta, ChatMessagesToPromptContext(messages)) -} - func (oc *AIClient) runResponsesAgentLoopPrompt( ctx context.Context, evt *event.Event, diff --git a/bridges/ai/subagent_spawn.go b/bridges/ai/subagent_spawn.go index 241efbcf..3fa98c0c 100644 --- a/bridges/ai/subagent_spawn.go +++ b/bridges/ai/subagent_spawn.go @@ -9,7 +9,6 @@ import ( "time" "github.com/google/uuid" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" diff --git a/pkg/integrations/memory/integration.go b/pkg/integrations/memory/integration.go index b7357bb0..fddfdf7f 100644 --- a/pkg/integrations/memory/integration.go +++ b/pkg/integrations/memory/integration.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/openai/openai-go/v3" "go.mau.fi/util/dbutil" "github.com/beeper/agentremote/pkg/agents" From ac2d86926bf4e5436839a47505faf077b7c6f93d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:09:47 +0200 Subject: [PATCH 05/23] Remove unused OpenAI imports; replace slices import Replace the unused "slices" import with "strings" in response_retry.go and remove unused OpenAI imports from streaming_chat_completions.go and streaming_executor.go. These changes clean up imports to fix unused-import warnings and do not change runtime behavior. --- bridges/ai/response_retry.go | 2 +- bridges/ai/streaming_chat_completions.go | 1 - bridges/ai/streaming_executor.go | 1 - docs/prompt-flow.html | 178 +++++++++++++++++++++++ 4 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 docs/prompt-flow.html diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index 9a0e05dd..305bf8a4 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "math" - "slices" + "strings" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" diff --git a/bridges/ai/streaming_chat_completions.go b/bridges/ai/streaming_chat_completions.go index dbf81747..4d38105f 100644 --- a/bridges/ai/streaming_chat_completions.go +++ b/bridges/ai/streaming_chat_completions.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" diff --git a/bridges/ai/streaming_executor.go b/bridges/ai/streaming_executor.go index fb9de84a..9d20ba86 100644 --- a/bridges/ai/streaming_executor.go +++ b/bridges/ai/streaming_executor.go @@ -3,7 +3,6 @@ package ai import ( "context" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" diff --git a/docs/prompt-flow.html b/docs/prompt-flow.html new file mode 100644 index 00000000..4f5dc9b6 --- /dev/null +++ b/docs/prompt-flow.html @@ -0,0 +1,178 @@ + + + + + + AI Bridge Prompt Flow + + + +
+

AI Bridge Prompt Flow

+

Final architecture after the prompt-pipeline cleanup. Internal state is canonical PromptContext. Provider payloads are created only at the API edge.

+ +
+
End To End
+

Runtime Flow

+
Inbound event
+  |
+  v
+Attachment + link preprocessing
+  - images stay native
+  - files/PDFs become text context
+  - audio/video must become transcript/description first
+  |
+  v
+Canonical prompt assembly
+  - system prompt text
+  - replayed canonical history
+  - current canonical user turn
+  |
+  v
+Canonical agent loop
+  - PromptContext in
+  - PromptMessage follow-ups
+  - retry/overflow works on PromptContext
+  |
+  v
+Provider-edge serialization
+  - Chat Completions payload
+  - Responses payload
+  |
+  v
+Streaming persistence + canonical turn storage
+
+ +
+
System Prompt
+

System Prompt Order

+
1. session greeting (when included)
+2. base agent prompt or default system prompt
+3. bridge core fragments
+   - group intro
+   - desktop account hint
+   - sessionKey hint
+4. bridge-local memory prompt context
+5. trusted inbound metadata
+

There is no shared prompt-integration registry anymore. Memory prompt augmentation is bridge-local.

+
+ +
+
Canonical Model
+

Internal Prompt Types

+
PromptContext
+  systemPrompt: string
+  messages: PromptMessage[]
+
+PromptMessage roles
+  - user
+  - assistant
+  - tool_result
+
+PromptBlock types
+  - text
+  - image
+  - thinking
+  - tool_call
+

There are no native file, audio, or video prompt blocks in the bridge runtime anymore.

+
+ +
+
What Survived
+

Intentional Runtime Seams

+
- one preprocessing pipeline
+- one prompt assembler
+- one canonical history replay path
+- one canonical continuation path
+- one canonical overflow/compaction path
+- two provider loops
+  - chat completions
+  - responses
+
+ +
+
Removed
+

Deleted Legacy Paths

+
- shared PromptIntegration abstraction
+- promptRegistry / additionalSystemMessages / augmentPromptWithIntegrations
+- promptContextToDispatchMessages
+- split steering follow-up chat-message builders
+- chat-message-union loop state as the internal source of truth
+- legacy config roots:
+  - providers
+  - pruning
+  - link_previews
+  - tools.search
+  - tools.fetch
+
+
+ + From ba9fe0e643f9c435909b0332e53273ba3653cbbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:13:11 +0200 Subject: [PATCH 06/23] Delete prompt-flow.html --- docs/prompt-flow.html | 178 ------------------------------------------ 1 file changed, 178 deletions(-) delete mode 100644 docs/prompt-flow.html diff --git a/docs/prompt-flow.html b/docs/prompt-flow.html deleted file mode 100644 index 4f5dc9b6..00000000 --- a/docs/prompt-flow.html +++ /dev/null @@ -1,178 +0,0 @@ - - - - - - AI Bridge Prompt Flow - - - -
-

AI Bridge Prompt Flow

-

Final architecture after the prompt-pipeline cleanup. Internal state is canonical PromptContext. Provider payloads are created only at the API edge.

- -
-
End To End
-

Runtime Flow

-
Inbound event
-  |
-  v
-Attachment + link preprocessing
-  - images stay native
-  - files/PDFs become text context
-  - audio/video must become transcript/description first
-  |
-  v
-Canonical prompt assembly
-  - system prompt text
-  - replayed canonical history
-  - current canonical user turn
-  |
-  v
-Canonical agent loop
-  - PromptContext in
-  - PromptMessage follow-ups
-  - retry/overflow works on PromptContext
-  |
-  v
-Provider-edge serialization
-  - Chat Completions payload
-  - Responses payload
-  |
-  v
-Streaming persistence + canonical turn storage
-
- -
-
System Prompt
-

System Prompt Order

-
1. session greeting (when included)
-2. base agent prompt or default system prompt
-3. bridge core fragments
-   - group intro
-   - desktop account hint
-   - sessionKey hint
-4. bridge-local memory prompt context
-5. trusted inbound metadata
-

There is no shared prompt-integration registry anymore. Memory prompt augmentation is bridge-local.

-
- -
-
Canonical Model
-

Internal Prompt Types

-
PromptContext
-  systemPrompt: string
-  messages: PromptMessage[]
-
-PromptMessage roles
-  - user
-  - assistant
-  - tool_result
-
-PromptBlock types
-  - text
-  - image
-  - thinking
-  - tool_call
-

There are no native file, audio, or video prompt blocks in the bridge runtime anymore.

-
- -
-
What Survived
-

Intentional Runtime Seams

-
- one preprocessing pipeline
-- one prompt assembler
-- one canonical history replay path
-- one canonical continuation path
-- one canonical overflow/compaction path
-- two provider loops
-  - chat completions
-  - responses
-
- -
-
Removed
-

Deleted Legacy Paths

-
- shared PromptIntegration abstraction
-- promptRegistry / additionalSystemMessages / augmentPromptWithIntegrations
-- promptContextToDispatchMessages
-- split steering follow-up chat-message builders
-- chat-message-union loop state as the internal source of truth
-- legacy config roots:
-  - providers
-  - pruning
-  - link_previews
-  - tools.search
-  - tools.fetch
-
-
- - From 63fb13a2e06d58399615da061f1d593f04e5fc65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:13:34 +0200 Subject: [PATCH 07/23] Resolve default command prefixes and add tests Use bridgesdk.ResolveCommandPrefix to resolve the configured command prefix with a hardcoded fallback in Codex and DummyBridge connectors (fallbacks: "!ai" and "!dummybridge"). Add tests (TestGetNameUsesDefaultCommandPrefixBeforeStartup) for both bridges to assert the DefaultCommandPrefix value before startup. Files changed: bridges/codex/constructors.go, bridges/dummybridge/connector.go, and tests in bridges/codex/connector_test.go and bridges/dummybridge/connector_test.go. --- bridges/codex/connector_test.go | 7 +++++++ bridges/codex/constructors.go | 2 +- bridges/dummybridge/connector.go | 2 +- bridges/dummybridge/connector_test.go | 7 +++++++ 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/bridges/codex/connector_test.go b/bridges/codex/connector_test.go index c6fd518d..f3240237 100644 --- a/bridges/codex/connector_test.go +++ b/bridges/codex/connector_test.go @@ -38,6 +38,13 @@ func TestGetCapabilitiesEnablesContactListProvisioning(t *testing.T) { } } +func TestGetNameUsesDefaultCommandPrefixBeforeStartup(t *testing.T) { + conn := NewConnector() + if got := conn.GetName().DefaultCommandPrefix; got != "!ai" { + t.Fatalf("expected default command prefix !ai, got %q", got) + } +} + func TestHostAuthLoginIDUsesDedicatedPrefix(t *testing.T) { conn := NewConnector() mxid := id.UserID("@alice:example.com") diff --git a/bridges/codex/constructors.go b/bridges/codex/constructors.go index f36d67df..d73d89ae 100644 --- a/bridges/codex/constructors.go +++ b/bridges/codex/constructors.go @@ -65,7 +65,7 @@ func NewConnector() *CodexConnector { BeeperBridgeType: "codex", DefaultPort: 29346, DefaultCommandPrefix: func() string { - return cc.Config.Bridge.CommandPrefix + return bridgesdk.ResolveCommandPrefix(cc.Config.Bridge.CommandPrefix, "!ai") }, FillBridgeInfo: func(portal *bridgev2.Portal, content *event.BridgeEventContent) { if portal == nil { diff --git a/bridges/dummybridge/connector.go b/bridges/dummybridge/connector.go index b9338554..d32adad9 100644 --- a/bridges/dummybridge/connector.go +++ b/bridges/dummybridge/connector.go @@ -52,7 +52,7 @@ func NewConnector() *DummyBridgeConnector { BeeperBridgeType: "dummybridge", DefaultPort: 29349, DefaultCommandPrefix: func() string { - return dc.Config.Bridge.CommandPrefix + return bridgesdk.ResolveCommandPrefix(dc.Config.Bridge.CommandPrefix, "!dummybridge") }, ExampleConfig: exampleNetworkConfig, ConfigData: &dc.Config, diff --git a/bridges/dummybridge/connector_test.go b/bridges/dummybridge/connector_test.go index 336ba315..263f60e4 100644 --- a/bridges/dummybridge/connector_test.go +++ b/bridges/dummybridge/connector_test.go @@ -38,3 +38,10 @@ func TestGetCapabilitiesExposeProvisioningSearchAndContacts(t *testing.T) { t.Fatal("expected search provisioning to be enabled") } } + +func TestGetNameUsesDefaultCommandPrefixBeforeStartup(t *testing.T) { + conn := NewConnector() + if got := conn.GetName().DefaultCommandPrefix; got != "!dummybridge" { + t.Fatalf("expected default command prefix !dummybridge, got %q", got) + } +} From 66eeb4162ec9ff7686348f6a2b615570bd7e951d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:25:53 +0200 Subject: [PATCH 08/23] Add LoginCredentials; rename prompt helpers Introduce LoginCredentials to hold per-login api_key/base_url/service_tokens and migrate UserLoginMetadata to use Credentials. Update accessors, merging/clone helpers, and service token handling across MCP, image generation, media, provisioning, and login code. Rename several prompt conversion and helper functions to unexported variants (e.g. PromptContextToResponsesInput -> promptContextToResponsesInput, ChatMessagesToPromptContext -> chatMessagesToPromptContext) and update all call sites, including provider, streaming, compaction, and logging changes to use PromptContext. Add a test to ensure legacy config sections are removed from example config and remove those sections from the bridge example YAML. Misc: adapt error logging to accept PromptContext, adjust continuation/steering logic to append steering messages into the canonical prompt, and update related tests. --- bridges/ai/agent_loop_steering_test.go | 7 +- bridges/ai/chat.go | 2 +- bridges/ai/client.go | 2 +- bridges/ai/compaction_summarization.go | 6 +- bridges/ai/desktop_api_sessions.go | 8 +- bridges/ai/error_logging.go | 20 +- bridges/ai/image_generation_tool.go | 14 +- .../image_generation_tool_magic_proxy_test.go | 36 +- bridges/ai/integrations_example-config.yaml | 23 -- .../ai/integrations_example_config_test.go | 68 ++++ bridges/ai/login.go | 43 +-- bridges/ai/login_loaders.go | 4 +- bridges/ai/login_loaders_test.go | 10 +- bridges/ai/magic_proxy_test.go | 16 +- bridges/ai/mcp_helpers.go | 30 +- bridges/ai/mcp_servers.go | 8 +- bridges/ai/mcp_servers_test.go | 2 +- .../media_understanding_runner_openai_test.go | 12 +- bridges/ai/messages_responses_input_test.go | 2 +- bridges/ai/metadata.go | 156 +++++++- bridges/ai/model_catalog_test.go | 4 +- bridges/ai/prompt_context_local.go | 8 +- bridges/ai/provider_openai_chat.go | 2 +- bridges/ai/provider_openai_responses.go | 6 +- bridges/ai/provisioning.go | 19 +- bridges/ai/response_retry.go | 8 +- bridges/ai/runtime_compaction_adapter.go | 2 +- bridges/ai/streaming_chat_completions.go | 6 +- bridges/ai/streaming_continuation.go | 24 +- bridges/ai/streaming_executor.go | 3 +- bridges/ai/streaming_init.go | 8 +- bridges/ai/streaming_init_test.go | 10 +- bridges/ai/streaming_input_conversion.go | 30 +- bridges/ai/streaming_responses_api.go | 22 +- bridges/ai/streaming_responses_input_test.go | 4 +- bridges/ai/streaming_state.go | 2 - bridges/ai/token_resolver.go | 60 ++- bridges/ai/tools.go | 2 +- bridges/ai/tools_search_fetch.go | 2 +- bridges/ai/tools_search_fetch_test.go | 10 +- bridges/ai/tools_tts_test.go | 8 +- .../integrations_example-config.yaml | 364 ------------------ 42 files changed, 435 insertions(+), 638 deletions(-) create mode 100644 bridges/ai/integrations_example_config_test.go delete mode 100644 pkg/connector/integrations_example-config.yaml diff --git a/bridges/ai/agent_loop_steering_test.go b/bridges/ai/agent_loop_steering_test.go index 703f6a87..e3be7ee7 100644 --- a/bridges/ai/agent_loop_steering_test.go +++ b/bridges/ai/agent_loop_steering_test.go @@ -263,16 +263,17 @@ func TestBuildContinuationParams_UsesPendingSteeringPromptsBeforeDrainingQueue(t } state := &streamingState{roomID: roomID} state.addPendingSteeringPrompts([]string{"pending steer"}) + prompt := PromptContext{} - params := oc.buildContinuationParams(context.Background(), state, nil, nil, nil) + params := oc.buildContinuationParams(context.Background(), &prompt, state, nil, nil, nil) if len(params.Input.OfInputItemList) == 0 { t.Fatal("expected continuation input to include stored steering prompt") } if pending := state.consumePendingSteeringPrompts(); len(pending) != 0 { t.Fatalf("expected pending steering prompts to be consumed, got %#v", pending) } - if len(state.baseInput) == 0 { - t.Fatal("expected steering input to persist in base input even when history starts empty") + if len(prompt.Messages) == 0 { + t.Fatal("expected steering input to persist in canonical prompt even when history starts empty") } if snapshot := oc.getRoomRun(roomID); snapshot == nil || len(snapshot.steerQueue) != 1 { t.Fatalf("expected queued steering item to remain available, got %#v", snapshot) diff --git a/bridges/ai/chat.go b/bridges/ai/chat.go index 46c85a65..caf7884a 100644 --- a/bridges/ai/chat.go +++ b/bridges/ai/chat.go @@ -122,7 +122,7 @@ func (oc *AIClient) canUseImageGeneration() bool { return false } loginMeta := loginMetadata(oc.UserLogin) - if loginMeta == nil || loginMeta.APIKey == "" { + if loginMeta == nil || strings.TrimSpace(oc.connector.resolveProviderAPIKey(loginMeta)) == "" { return false } switch loginMeta.Provider { diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 70008d7b..4092881c 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -487,7 +487,7 @@ func initProviderForLogin(key string, meta *UserLoginMetadata, connector *OpenAI return initOpenRouterProvider(key, connector.resolveOpenRouterBaseURL(), "", connector.defaultPDFEngineForInit(), ProviderOpenRouter, log) case ProviderMagicProxy: - baseURL := normalizeProxyBaseURL(meta.BaseURL) + baseURL := normalizeProxyBaseURL(loginCredentialBaseURL(meta)) if baseURL == "" { return nil, errors.New("magic proxy base_url is required") } diff --git a/bridges/ai/compaction_summarization.go b/bridges/ai/compaction_summarization.go index 032ab8b2..3883edd1 100644 --- a/bridges/ai/compaction_summarization.go +++ b/bridges/ai/compaction_summarization.go @@ -619,8 +619,8 @@ func (oc *AIClient) applyCompactionModelSummaryAndRefresh( decision airuntime.CompactionDecision, contextWindowTokens int, ) PromptContext { - originalMessages := PromptContextToChatCompletionMessages(originalPrompt, false) - compactedMessages := PromptContextToChatCompletionMessages(compactedPrompt, false) + originalMessages := promptContextToChatCompletionMessages(originalPrompt, false) + compactedMessages := promptContextToChatCompletionMessages(compactedPrompt, false) out := compactedMessages if oc.pruningSummarizationEnabled() { dropped := selectDroppedCompactionMessages(originalMessages, compactedMessages, decision.DroppedCount) @@ -655,5 +655,5 @@ func (oc *AIClient) applyCompactionModelSummaryAndRefresh( if refresh := strings.TrimSpace(oc.pruningPostCompactionRefreshPrompt()); refresh != "" { out = injectSystemPromptAtFirstNonSystem(out, refresh) } - return ChatMessagesToPromptContext(out) + return chatMessagesToPromptContext(out) } diff --git a/bridges/ai/desktop_api_sessions.go b/bridges/ai/desktop_api_sessions.go index 05d4123d..2911d6a1 100644 --- a/bridges/ai/desktop_api_sessions.go +++ b/bridges/ai/desktop_api_sessions.go @@ -158,11 +158,11 @@ func (oc *AIClient) desktopAPIInstances() map[string]DesktopAPIInstance { if oc == nil || oc.UserLogin == nil { return instances } - meta := loginMetadata(oc.UserLogin) - if meta == nil || meta.ServiceTokens == nil { + creds := loginCredentials(loginMetadata(oc.UserLogin)) + if creds == nil || creds.ServiceTokens == nil { return instances } - for name, instance := range meta.ServiceTokens.DesktopAPIInstances { + for name, instance := range creds.ServiceTokens.DesktopAPIInstances { key := normalizeDesktopInstanceName(name) if key == "" { continue @@ -172,7 +172,7 @@ func (oc *AIClient) desktopAPIInstances() map[string]DesktopAPIInstance { } instances[key] = instance } - if token := strings.TrimSpace(meta.ServiceTokens.DesktopAPI); token != "" { + if token := strings.TrimSpace(creds.ServiceTokens.DesktopAPI); token != "" { if _, ok := instances[desktopDefaultInstance]; !ok { instances[desktopDefaultInstance] = DesktopAPIInstance{Token: token} } diff --git a/bridges/ai/error_logging.go b/bridges/ai/error_logging.go index add53183..6307918e 100644 --- a/bridges/ai/error_logging.go +++ b/bridges/ai/error_logging.go @@ -9,14 +9,14 @@ import ( "github.com/rs/zerolog" ) -func logResponsesFailure(log zerolog.Logger, err error, params responses.ResponseNewParams, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, stage string) { - logProviderFailure(log, err, meta, messages, stage, "Responses API failure", func(event *zerolog.Event) { +func logResponsesFailure(log zerolog.Logger, err error, params responses.ResponseNewParams, meta *PortalMetadata, prompt PromptContext, stage string) { + logProviderFailure(log, err, meta, prompt, stage, "Responses API failure", func(event *zerolog.Event) { addResponsesParamsSummary(event, params) }) } -func logChatCompletionsFailure(log zerolog.Logger, err error, params openai.ChatCompletionNewParams, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, stage string) { - logProviderFailure(log, err, meta, messages, stage, "Chat Completions failure", func(event *zerolog.Event) { +func logChatCompletionsFailure(log zerolog.Logger, err error, params openai.ChatCompletionNewParams, meta *PortalMetadata, prompt PromptContext, stage string) { + logProviderFailure(log, err, meta, prompt, stage, "Chat Completions failure", func(event *zerolog.Event) { addChatParamsSummary(event, params) }) } @@ -25,13 +25,13 @@ func logProviderFailure( log zerolog.Logger, err error, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, stage string, msg string, addSummary func(*zerolog.Event), ) { event := log.Error().Err(err).Str("stage", stage) - addRequestSummary(event, meta, messages) + addRequestSummary(event, meta, prompt) if addSummary != nil { addSummary(event) } @@ -39,7 +39,7 @@ func logProviderFailure( event.Msg(msg) } -func addRequestSummary(event *zerolog.Event, metadata *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion) { +func addRequestSummary(event *zerolog.Event, metadata *PortalMetadata, prompt PromptContext) { if event == nil { return } @@ -54,9 +54,9 @@ func addRequestSummary(event *zerolog.Event, metadata *PortalMetadata, messages event.Str("runtime_model_override", metadata.RuntimeModelOverride) } } - event.Int("message_count", len(messages)) - event.Bool("has_audio", hasAudioContent(messages)) - event.Bool("has_multimodal", hasMultimodalContent(messages)) + event.Int("message_count", len(prompt.Messages)) + event.Bool("has_audio", promptHasAudioContent(prompt)) + event.Bool("has_multimodal", promptHasMultimodalContent(prompt)) } func addResponsesParamsSummary(event *zerolog.Event, params responses.ResponseNewParams) { diff --git a/bridges/ai/image_generation_tool.go b/bridges/ai/image_generation_tool.go index 5afa2d1e..c72423d9 100644 --- a/bridges/ai/image_generation_tool.go +++ b/bridges/ai/image_generation_tool.go @@ -232,7 +232,7 @@ func supportsOpenAIImageGen(btc *BridgeToolContext) bool { case ProviderOpenAI, ProviderMagicProxy: if loginMeta.Provider == ProviderMagicProxy { // Magic Proxy uses a per-login token+base URL, not the OpenAI config key. - return strings.TrimSpace(loginMeta.APIKey) != "" && strings.TrimSpace(loginMeta.BaseURL) != "" + return loginCredentialAPIKey(loginMeta) != "" && loginCredentialBaseURL(loginMeta) != "" } return btc.Client.connector.resolveOpenAIAPIKey(loginMeta) != "" default: @@ -261,8 +261,8 @@ func supportsGeminiImageGen(btc *BridgeToolContext) bool { return strings.TrimSpace(svc.BaseURL) != "" && strings.TrimSpace(svc.APIKey) != "" } } - base := normalizeProxyBaseURL(loginMeta.BaseURL) - return base != "" && strings.TrimSpace(loginMeta.APIKey) != "" + base := normalizeProxyBaseURL(loginCredentialBaseURL(loginMeta)) + return base != "" && loginCredentialAPIKey(loginMeta) != "" default: return false } @@ -456,7 +456,7 @@ func buildOpenAIImagesBaseURL(btc *BridgeToolContext) (string, error) { return strings.TrimSuffix(strings.TrimSpace(svc.BaseURL), "/"), nil } } - base := normalizeProxyBaseURL(loginMeta.BaseURL) + base := normalizeProxyBaseURL(loginCredentialBaseURL(loginMeta)) if base == "" { return "", errors.New("magic proxy base_url is required for image generation") } @@ -482,7 +482,7 @@ func buildGeminiBaseURL(btc *BridgeToolContext) (string, error) { return strings.TrimSuffix(strings.TrimSpace(svc.BaseURL), "/"), nil } } - base := normalizeProxyBaseURL(loginMeta.BaseURL) + base := normalizeProxyBaseURL(loginCredentialBaseURL(loginMeta)) if base == "" { return "", errors.New("magic proxy base_url is required for image generation") } @@ -592,8 +592,8 @@ func resolveOpenRouterImageGenEndpoint(btc *BridgeToolContext) (baseURL string, // Provider-specific per-login endpoints. switch meta.Provider { case ProviderMagicProxy: - base := normalizeProxyBaseURL(meta.BaseURL) - key := trim(meta.APIKey) + base := normalizeProxyBaseURL(loginCredentialBaseURL(meta)) + key := trim(loginCredentialAPIKey(meta)) if base == "" || key == "" { return "", "", false } diff --git a/bridges/ai/image_generation_tool_magic_proxy_test.go b/bridges/ai/image_generation_tool_magic_proxy_test.go index a130f3db..e425df0a 100644 --- a/bridges/ai/image_generation_tool_magic_proxy_test.go +++ b/bridges/ai/image_generation_tool_magic_proxy_test.go @@ -5,8 +5,10 @@ import "testing" func TestResolveImageGenProviderMagicProxyPrefersOpenRouterForSimplePrompts(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -25,8 +27,10 @@ func TestResolveImageGenProviderMagicProxyPrefersOpenRouterForSimplePrompts(t *t func TestResolveImageGenProviderMagicProxyStillPrefersOpenRouterWhenCountIsGreaterThanOne(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -45,8 +49,10 @@ func TestResolveImageGenProviderMagicProxyStillPrefersOpenRouterWhenCountIsGreat func TestResolveImageGenProviderMagicProxyProviderOpenAIStillRoutesToOpenRouter(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -66,8 +72,10 @@ func TestResolveImageGenProviderMagicProxyProviderOpenAIStillRoutesToOpenRouter( func TestResolveImageGenProviderMagicProxyProviderGeminiUsesGemini(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -87,8 +95,10 @@ func TestResolveImageGenProviderMagicProxyProviderGeminiUsesGemini(t *testing.T) func TestBuildOpenAIImagesBaseURLMagicProxy(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -104,8 +114,10 @@ func TestBuildOpenAIImagesBaseURLMagicProxy(t *testing.T) { func TestBuildGeminiBaseURLMagicProxy(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) diff --git a/bridges/ai/integrations_example-config.yaml b/bridges/ai/integrations_example-config.yaml index f51aa046..3bed8a64 100644 --- a/bridges/ai/integrations_example-config.yaml +++ b/bridges/ai/integrations_example-config.yaml @@ -124,28 +124,6 @@ tools: models: - provider: "openrouter" model: "google/gemini-3-flash-preview" - chunking: - tokens: 400 - overlap: 80 - sync: - on_session_start: true - on_search: true - watch: true - watch_debounce_ms: 1500 - interval_minutes: 0 - sessions: - delta_bytes: 100000 - delta_messages: 50 - query: - max_results: 6 - min_score: 0.35 - hybrid: - candidate_multiplier: 4 - cache: - enabled: true - max_entries: -1 - experimental: - session_memory: false agents: defaults: @@ -190,4 +168,3 @@ agents: soft_threshold_tokens: 4000 prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - diff --git a/bridges/ai/integrations_example_config_test.go b/bridges/ai/integrations_example_config_test.go new file mode 100644 index 00000000..a5ca1a1b --- /dev/null +++ b/bridges/ai/integrations_example_config_test.go @@ -0,0 +1,68 @@ +package ai + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestExampleConfigFilesExcludeLegacyConfigSections(t *testing.T) { + legacyKeys := map[string]struct{}{ + "chunking": {}, + "sync": {}, + "query": {}, + "cache": {}, + "experimental": {}, + "pruning": {}, + "recall": {}, + } + + t.Run("bridge example", func(t *testing.T) { + rel := "integrations_example-config.yaml" + data, err := os.ReadFile(rel) + if err != nil { + t.Fatalf("read %s: %v", rel, err) + } + + var doc map[string]any + if err := yaml.Unmarshal(data, &doc); err != nil { + t.Fatalf("unmarshal %s: %v", rel, err) + } + + if path := findLegacyConfigKey(doc, legacyKeys, nil); path != "" { + t.Fatalf("found legacy config key %q in %s", path, rel) + } + }) + + t.Run("legacy connector example removed", func(t *testing.T) { + rel := filepath.Join("..", "..", "pkg", "connector", "integrations_example-config.yaml") + if _, err := os.Stat(rel); !os.IsNotExist(err) { + t.Fatalf("expected stale generic example %s to be removed, got err=%v", rel, err) + } + }) +} + +func findLegacyConfigKey(node any, legacyKeys map[string]struct{}, path []string) string { + switch value := node.(type) { + case map[string]any: + for key, child := range value { + if _, ok := legacyKeys[key]; ok { + return strings.Join(append(path, key), ".") + } + if found := findLegacyConfigKey(child, legacyKeys, append(path, key)); found != "" { + return found + } + } + case []any: + for idx, child := range value { + if found := findLegacyConfigKey(child, legacyKeys, append(path, fmt.Sprintf("[%d]", idx))); found != "" { + return found + } + } + } + return "" +} diff --git a/bridges/ai/login.go b/bridges/ai/login.go index 5f6f2b87..14abc785 100644 --- a/bridges/ai/login.go +++ b/bridges/ai/login.go @@ -220,11 +220,14 @@ func (ol *OpenAILogin) finishLogin(ctx context.Context, provider, apiKey, baseUR meta = &UserLoginMetadata{} } meta.Provider = provider - meta.APIKey = apiKey - meta.BaseURL = baseURL + creds := &LoginCredentials{ + APIKey: apiKey, + BaseURL: baseURL, + } if serviceTokens != nil && !serviceTokensEmpty(serviceTokens) { - meta.ServiceTokens = mergeServiceTokens(meta.ServiceTokens, serviceTokens) + creds.ServiceTokens = serviceTokens } + meta.Credentials = mergeLoginCredentials(meta.Credentials, creds) if err := ol.validateLoginMetadata(ctx, loginID, meta); err != nil { return nil, err } @@ -331,40 +334,6 @@ func (ol *OpenAILogin) validateLoginMetadata(ctx context.Context, loginID networ return nil } -func serviceTokensEmpty(tokens *ServiceTokens) bool { - if tokens == nil { - return true - } - if len(tokens.DesktopAPIInstances) > 0 { - for _, instance := range tokens.DesktopAPIInstances { - if strings.TrimSpace(instance.Token) != "" || strings.TrimSpace(instance.BaseURL) != "" { - return false - } - } - } - if len(tokens.MCPServers) > 0 { - for _, server := range tokens.MCPServers { - if strings.TrimSpace(server.Transport) != "" || - strings.TrimSpace(server.Endpoint) != "" || - strings.TrimSpace(server.Command) != "" || - len(server.Args) > 0 || - strings.TrimSpace(server.Token) != "" || - strings.TrimSpace(server.AuthURL) != "" || - strings.TrimSpace(server.AuthType) != "" || - strings.TrimSpace(server.Kind) != "" || - server.Connected { - return false - } - } - } - return strings.TrimSpace(tokens.OpenAI) == "" && - strings.TrimSpace(tokens.OpenRouter) == "" && - strings.TrimSpace(tokens.Exa) == "" && - strings.TrimSpace(tokens.Brave) == "" && - strings.TrimSpace(tokens.Perplexity) == "" && - strings.TrimSpace(tokens.DesktopAPI) == "" -} - func (ol *OpenAILogin) resolveCustomLogin(input map[string]string) (string, string, *ServiceTokens, error) { if input == nil { input = map[string]string{} diff --git a/bridges/ai/login_loaders.go b/bridges/ai/login_loaders.go index 189a6998..c18bcc49 100644 --- a/bridges/ai/login_loaders.go +++ b/bridges/ai/login_loaders.go @@ -34,13 +34,13 @@ func aiClientNeedsRebuild(existing *AIClient, key string, meta *UserLoginMetadat existingBaseURL := "" if existingMeta != nil { existingProvider = strings.TrimSpace(existingMeta.Provider) - existingBaseURL = stringutil.NormalizeBaseURL(existingMeta.BaseURL) + existingBaseURL = stringutil.NormalizeBaseURL(loginCredentialBaseURL(existingMeta)) } targetProvider := "" targetBaseURL := "" if meta != nil { targetProvider = strings.TrimSpace(meta.Provider) - targetBaseURL = stringutil.NormalizeBaseURL(meta.BaseURL) + targetBaseURL = stringutil.NormalizeBaseURL(loginCredentialBaseURL(meta)) } return existing.apiKey != key || !strings.EqualFold(existingProvider, targetProvider) || diff --git a/bridges/ai/login_loaders_test.go b/bridges/ai/login_loaders_test.go index ab02a1fc..2f3b57aa 100644 --- a/bridges/ai/login_loaders_test.go +++ b/bridges/ai/login_loaders_test.go @@ -27,19 +27,19 @@ func testUserLoginWithMeta(loginID networkid.UserLoginID, meta *UserLoginMetadat func TestAIClientNeedsRebuild(t *testing.T) { existing := &AIClient{ apiKey: "secret", - UserLogin: testUserLoginWithMeta("existing", &UserLoginMetadata{Provider: " OpenAI ", BaseURL: "https://api.example.com/v1/"}), + UserLogin: testUserLoginWithMeta("existing", &UserLoginMetadata{Provider: " OpenAI ", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1/"}}), } - if aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", BaseURL: "https://api.example.com/v1"}) { + if aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1"}}) { t.Fatal("expected no rebuild when key/provider/base URL are equivalent") } - if !aiClientNeedsRebuild(existing, "other-key", &UserLoginMetadata{Provider: "openai", BaseURL: "https://api.example.com/v1"}) { + if !aiClientNeedsRebuild(existing, "other-key", &UserLoginMetadata{Provider: "openai", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1"}}) { t.Fatal("expected rebuild when API key changes") } - if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openrouter", BaseURL: "https://api.example.com/v1"}) { + if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openrouter", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1"}}) { t.Fatal("expected rebuild when provider changes") } - if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", BaseURL: "https://api.other.example.com/v1"}) { + if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", Credentials: &LoginCredentials{BaseURL: "https://api.other.example.com/v1"}}) { t.Fatal("expected rebuild when base URL changes") } if !aiClientNeedsRebuild(nil, "secret", &UserLoginMetadata{Provider: "openai"}) { diff --git a/bridges/ai/magic_proxy_test.go b/bridges/ai/magic_proxy_test.go index 9e16c78c..f231482c 100644 --- a/bridges/ai/magic_proxy_test.go +++ b/bridges/ai/magic_proxy_test.go @@ -39,8 +39,10 @@ func TestResolveServiceConfigMagicProxyUsesJoinedPaths(t *testing.T) { oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } services := oc.resolveServiceConfig(meta) @@ -63,8 +65,10 @@ func TestResolveServiceConfigMagicProxyNoDuplicateOpenRouterPath(t *testing.T) { oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + }, } services := oc.resolveServiceConfig(meta) @@ -87,7 +91,9 @@ func TestResolveExaProxyBaseURLMagicProxyPrefersLoginBase(t *testing.T) { } meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - BaseURL: "https://ai.bt.hn/", + Credentials: &LoginCredentials{ + BaseURL: "https://ai.bt.hn/", + }, } if got := oc.resolveExaProxyBaseURL(meta); got != "https://ai.bt.hn/exa" { t.Fatalf("unexpected exa proxy base: %q", got) diff --git a/bridges/ai/mcp_helpers.go b/bridges/ai/mcp_helpers.go index 88b5aa86..689e8144 100644 --- a/bridges/ai/mcp_helpers.go +++ b/bridges/ai/mcp_helpers.go @@ -66,24 +66,32 @@ func (oc *AIClient) verifyMCPServerConnection(ctx context.Context, server namedM } func setLoginMCPServer(meta *UserLoginMetadata, name string, cfg MCPServerConfig) { - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} + creds := ensureLoginCredentials(meta) + if creds == nil { + return + } + if creds.ServiceTokens == nil { + creds.ServiceTokens = &ServiceTokens{} } - if meta.ServiceTokens.MCPServers == nil { - meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} + if creds.ServiceTokens.MCPServers == nil { + creds.ServiceTokens.MCPServers = map[string]MCPServerConfig{} } - meta.ServiceTokens.MCPServers[name] = normalizeMCPServerConfig(cfg) + creds.ServiceTokens.MCPServers[name] = normalizeMCPServerConfig(cfg) } func clearLoginMCPServer(meta *UserLoginMetadata, name string) { - if meta == nil || meta.ServiceTokens == nil || meta.ServiceTokens.MCPServers == nil { + creds := loginCredentials(meta) + if creds == nil || creds.ServiceTokens == nil || creds.ServiceTokens.MCPServers == nil { return } - delete(meta.ServiceTokens.MCPServers, name) - if len(meta.ServiceTokens.MCPServers) == 0 { - meta.ServiceTokens.MCPServers = nil + delete(creds.ServiceTokens.MCPServers, name) + if len(creds.ServiceTokens.MCPServers) == 0 { + creds.ServiceTokens.MCPServers = nil + } + if serviceTokensEmpty(creds.ServiceTokens) { + creds.ServiceTokens = nil } - if serviceTokensEmpty(meta.ServiceTokens) { - meta.ServiceTokens = nil + if loginCredentialsEmpty(creds) { + meta.Credentials = nil } } diff --git a/bridges/ai/mcp_servers.go b/bridges/ai/mcp_servers.go index 64c188ce..8f88460e 100644 --- a/bridges/ai/mcp_servers.go +++ b/bridges/ai/mcp_servers.go @@ -144,12 +144,12 @@ func (oc *AIClient) loginMCPServers() map[string]MCPServerConfig { if oc == nil || oc.UserLogin == nil { return nil } - meta := loginMetadata(oc.UserLogin) - if meta == nil || meta.ServiceTokens == nil || len(meta.ServiceTokens.MCPServers) == 0 { + tokens := loginCredentialServiceTokens(loginMetadata(oc.UserLogin)) + if tokens == nil || len(tokens.MCPServers) == 0 { return nil } - out := make(map[string]MCPServerConfig, len(meta.ServiceTokens.MCPServers)) - for rawName, rawCfg := range meta.ServiceTokens.MCPServers { + out := make(map[string]MCPServerConfig, len(tokens.MCPServers)) + for rawName, rawCfg := range tokens.MCPServers { name := normalizeMCPServerName(rawName) cfg := normalizeMCPServerConfig(rawCfg) if name == "" { diff --git a/bridges/ai/mcp_servers_test.go b/bridges/ai/mcp_servers_test.go index 3dfe93d8..973f5365 100644 --- a/bridges/ai/mcp_servers_test.go +++ b/bridges/ai/mcp_servers_test.go @@ -10,7 +10,7 @@ import ( ) func testAIClientWithMCPServers(servers map[string]MCPServerConfig) *AIClient { - meta := &UserLoginMetadata{ServiceTokens: &ServiceTokens{MCPServers: servers}} + meta := &UserLoginMetadata{Credentials: &LoginCredentials{ServiceTokens: &ServiceTokens{MCPServers: servers}}} login := &database.UserLogin{ID: networkid.UserLoginID("login"), Metadata: meta} userLogin := &bridgev2.UserLogin{UserLogin: login, Log: zerolog.Nop()} return &AIClient{UserLogin: userLogin} diff --git a/bridges/ai/media_understanding_runner_openai_test.go b/bridges/ai/media_understanding_runner_openai_test.go index 745307a4..2d784cf5 100644 --- a/bridges/ai/media_understanding_runner_openai_test.go +++ b/bridges/ai/media_understanding_runner_openai_test.go @@ -26,8 +26,10 @@ func TestResolveMediaProviderAPIKeyOpenAIMagicProxyUsesLoginToken(t *testing.T) meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } client := newMediaTestClient(meta, &OpenAIConnector{}) @@ -39,8 +41,10 @@ func TestResolveMediaProviderAPIKeyOpenAIMagicProxyUsesLoginToken(t *testing.T) func TestResolveOpenAIMediaBaseURLMagicProxyUsesOpenAIServicePath(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } client := newMediaTestClient(meta, &OpenAIConnector{}) diff --git a/bridges/ai/messages_responses_input_test.go b/bridges/ai/messages_responses_input_test.go index 9a3a09ee..715889ea 100644 --- a/bridges/ai/messages_responses_input_test.go +++ b/bridges/ai/messages_responses_input_test.go @@ -7,7 +7,7 @@ import ( ) func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { - input := PromptContextToResponsesInput(UserPromptContext( + input := promptContextToResponsesInput(UserPromptContext( PromptBlock{Type: PromptBlockText, Text: "hello"}, PromptBlock{Type: PromptBlockImage, ImageB64: "aGVsbG8=", MimeType: "image/png"}, )) diff --git a/bridges/ai/metadata.go b/bridges/ai/metadata.go index e6492b93..149e6995 100644 --- a/bridges/ai/metadata.go +++ b/bridges/ai/metadata.go @@ -4,6 +4,7 @@ import ( "encoding/json" "maps" "slices" + "strings" "go.mau.fi/util/jsontime" "go.mau.fi/util/random" @@ -53,6 +54,13 @@ type UserProfile struct { CustomInstructions string `json:"custom_instructions,omitempty"` } +// LoginCredentials stores the per-login credentials and service-specific tokens. +type LoginCredentials struct { + APIKey string `json:"api_key,omitempty"` + BaseURL string `json:"base_url,omitempty"` + ServiceTokens *ServiceTokens `json:"service_tokens,omitempty"` +} + // ServiceTokens stores optional per-login credentials for external services. type ServiceTokens struct { OpenAI string `json:"openai,omitempty"` @@ -109,26 +117,22 @@ type BuiltinAlwaysAllowRule struct { // UserLoginMetadata is stored on each login row to keep per-user settings. type UserLoginMetadata struct { - Provider string `json:"provider,omitempty"` // Selected provider (beeper, openai, openrouter) - APIKey string `json:"api_key,omitempty"` - BaseURL string `json:"base_url,omitempty"` // Per-user API endpoint - TitleGenerationModel string `json:"title_generation_model,omitempty"` // Model to use for generating chat titles - Agents *bool `json:"agents,omitempty"` // Nil/true enables agents, false limits login to model rooms - NextChatIndex int `json:"next_chat_index,omitempty"` - DefaultChatPortalID string `json:"default_chat_portal_id,omitempty"` - ModelCache *ModelCache `json:"model_cache,omitempty"` - ChatsSynced bool `json:"chats_synced,omitempty"` // True after initial bootstrap completed successfully - Gravatar *GravatarState `json:"gravatar,omitempty"` - Timezone string `json:"timezone,omitempty"` - Profile *UserProfile `json:"profile,omitempty"` + Provider string `json:"provider,omitempty"` // Selected provider (beeper, openai, openrouter) + Credentials *LoginCredentials `json:"credentials,omitempty"` + TitleGenerationModel string `json:"title_generation_model,omitempty"` // Model to use for generating chat titles + Agents *bool `json:"agents,omitempty"` // Nil/true enables agents, false limits login to model rooms + NextChatIndex int `json:"next_chat_index,omitempty"` + DefaultChatPortalID string `json:"default_chat_portal_id,omitempty"` + ModelCache *ModelCache `json:"model_cache,omitempty"` + ChatsSynced bool `json:"chats_synced,omitempty"` // True after initial bootstrap completed successfully + Gravatar *GravatarState `json:"gravatar,omitempty"` + Timezone string `json:"timezone,omitempty"` + Profile *UserProfile `json:"profile,omitempty"` // FileAnnotationCache stores parsed PDF content from OpenRouter's file-parser plugin // Key is the file hash (SHA256), pruned after 7 days FileAnnotationCache map[string]FileAnnotation `json:"file_annotation_cache,omitempty"` - // Optional per-login tokens for external services - ServiceTokens *ServiceTokens `json:"service_tokens,omitempty"` - // Tool approval rules (e.g. "always allow" decisions for MCP approvals or dangerous builtin tools). ToolApprovals *ToolApprovalsConfig `json:"tool_approvals,omitempty"` @@ -146,6 +150,128 @@ type UserLoginMetadata struct { LastErrorAt int64 `json:"last_error_at,omitempty"` // Unix timestamp } +func loginCredentials(meta *UserLoginMetadata) *LoginCredentials { + if meta == nil { + return nil + } + return meta.Credentials +} + +func ensureLoginCredentials(meta *UserLoginMetadata) *LoginCredentials { + if meta == nil { + return nil + } + if meta.Credentials == nil { + meta.Credentials = &LoginCredentials{} + } + return meta.Credentials +} + +func loginCredentialAPIKey(meta *UserLoginMetadata) string { + if creds := loginCredentials(meta); creds != nil { + return strings.TrimSpace(creds.APIKey) + } + return "" +} + +func loginCredentialBaseURL(meta *UserLoginMetadata) string { + if creds := loginCredentials(meta); creds != nil { + return strings.TrimSpace(creds.BaseURL) + } + return "" +} + +func loginCredentialServiceTokens(meta *UserLoginMetadata) *ServiceTokens { + if creds := loginCredentials(meta); creds != nil { + return creds.ServiceTokens + } + return nil +} + +func loginCredentialsEmpty(creds *LoginCredentials) bool { + if creds == nil { + return true + } + return strings.TrimSpace(creds.APIKey) == "" && + strings.TrimSpace(creds.BaseURL) == "" && + serviceTokensEmpty(creds.ServiceTokens) +} + +func cloneServiceTokens(src *ServiceTokens) *ServiceTokens { + if src == nil { + return nil + } + clone := *src + if src.DesktopAPIInstances != nil { + clone.DesktopAPIInstances = maps.Clone(src.DesktopAPIInstances) + } + if src.MCPServers != nil { + clone.MCPServers = maps.Clone(src.MCPServers) + } + return &clone +} + +func mergeLoginCredentials(existing, incoming *LoginCredentials) *LoginCredentials { + if incoming == nil { + return existing + } + if existing == nil { + clone := *incoming + clone.ServiceTokens = cloneServiceTokens(incoming.ServiceTokens) + if loginCredentialsEmpty(&clone) { + return nil + } + return &clone + } + + merged := *existing + if strings.TrimSpace(incoming.APIKey) != "" { + merged.APIKey = incoming.APIKey + } + if strings.TrimSpace(incoming.BaseURL) != "" { + merged.BaseURL = incoming.BaseURL + } + merged.ServiceTokens = mergeServiceTokens(existing.ServiceTokens, incoming.ServiceTokens) + if loginCredentialsEmpty(&merged) { + return nil + } + return &merged +} + +func serviceTokensEmpty(tokens *ServiceTokens) bool { + if tokens == nil { + return true + } + if len(tokens.DesktopAPIInstances) > 0 { + for _, instance := range tokens.DesktopAPIInstances { + if strings.TrimSpace(instance.Token) != "" || strings.TrimSpace(instance.BaseURL) != "" { + return false + } + } + } + if len(tokens.MCPServers) > 0 { + for _, server := range tokens.MCPServers { + if strings.TrimSpace(server.Transport) != "" || + strings.TrimSpace(server.Endpoint) != "" || + strings.TrimSpace(server.Command) != "" || + len(server.Args) > 0 || + strings.TrimSpace(server.Token) != "" || + strings.TrimSpace(server.AuthURL) != "" || + strings.TrimSpace(server.AuthType) != "" || + strings.TrimSpace(server.Kind) != "" || + server.Connected { + return false + } + } + } + return strings.TrimSpace(tokens.OpenAI) == "" && + strings.TrimSpace(tokens.OpenRouter) == "" && + strings.TrimSpace(tokens.Exa) == "" && + strings.TrimSpace(tokens.Brave) == "" && + strings.TrimSpace(tokens.Perplexity) == "" && + strings.TrimSpace(tokens.DesktopAPI) == "" +} + // HeartbeatState tracks last heartbeat delivery for dedupe. type HeartbeatState struct { LastHeartbeatText string `json:"last_heartbeat_text,omitempty"` diff --git a/bridges/ai/model_catalog_test.go b/bridges/ai/model_catalog_test.go index 810817df..02e21fd0 100644 --- a/bridges/ai/model_catalog_test.go +++ b/bridges/ai/model_catalog_test.go @@ -10,7 +10,9 @@ func TestImplicitModelCatalogEntries_MagicProxySeedsCatalog(t *testing.T) { // Magic Proxy logins store the API key on the login metadata. meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "mp-token", + Credentials: &LoginCredentials{ + APIKey: "mp-token", + }, } entries := oc.implicitModelCatalogEntries(meta) diff --git a/bridges/ai/prompt_context_local.go b/bridges/ai/prompt_context_local.go index fccf8944..b243e23b 100644 --- a/bridges/ai/prompt_context_local.go +++ b/bridges/ai/prompt_context_local.go @@ -47,7 +47,7 @@ func resolveBlockImageURL(block PromptBlock) string { return imageURL } -func PromptContextToResponsesInput(ctx PromptContext) responses.ResponseInputParam { +func promptContextToResponsesInput(ctx PromptContext) responses.ResponseInputParam { var result responses.ResponseInputParam for _, msg := range ctx.Messages { result = append(result, promptMessageToResponsesInputs(msg)...) @@ -123,7 +123,7 @@ func promptMessageToResponsesInputs(msg PromptMessage) responses.ResponseInputPa } } -func PromptContextToChatCompletionMessages(ctx PromptContext, supportsVideoURL bool) []openai.ChatCompletionMessageParamUnion { +func promptContextToChatCompletionMessages(ctx PromptContext, supportsVideoURL bool) []openai.ChatCompletionMessageParamUnion { var messages []openai.ChatCompletionMessageParamUnion if strings.TrimSpace(ctx.SystemPrompt) != "" { messages = append(messages, openai.SystemMessage(strings.TrimSpace(ctx.SystemPrompt))) @@ -238,7 +238,7 @@ func promptToolToChatMessage(msg PromptMessage) *openai.ChatCompletionToolMessag } } -func ChatMessagesToPromptContext(messages []openai.ChatCompletionMessageParamUnion) PromptContext { +func chatMessagesToPromptContext(messages []openai.ChatCompletionMessageParamUnion) PromptContext { var ctx PromptContext for _, msg := range messages { appendChatMessageToPromptContext(&ctx, msg) @@ -357,7 +357,7 @@ func inferPromptMimeTypeFromDataURL(value string) string { return rest[:idx] } -func HasUnsupportedResponsesPromptContext(ctx PromptContext) bool { +func hasUnsupportedResponsesPromptContext(ctx PromptContext) bool { for _, msg := range ctx.Messages { for _, block := range msg.Blocks { switch block.Type { diff --git a/bridges/ai/provider_openai_chat.go b/bridges/ai/provider_openai_chat.go index f16f0e1a..a9a6a260 100644 --- a/bridges/ai/provider_openai_chat.go +++ b/bridges/ai/provider_openai_chat.go @@ -9,7 +9,7 @@ import ( ) func (o *OpenAIProvider) generateChatCompletions(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - chatMessages := PromptContextToChatCompletionMessages(params.Context, isOpenRouterBaseURL(o.baseURL)) + chatMessages := promptContextToChatCompletionMessages(params.Context, isOpenRouterBaseURL(o.baseURL)) if len(chatMessages) == 0 { return nil, errors.New("no chat messages for completion") } diff --git a/bridges/ai/provider_openai_responses.go b/bridges/ai/provider_openai_responses.go index d934c5a0..c0f83dde 100644 --- a/bridges/ai/provider_openai_responses.go +++ b/bridges/ai/provider_openai_responses.go @@ -22,7 +22,7 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R responsesParams := responses.ResponseNewParams{ Model: params.Model, Input: responses.ResponseNewParamsInputUnion{ - OfInputItemList: PromptContextToResponsesInput(params.Context), + OfInputItemList: promptContextToResponsesInput(params.Context), }, } @@ -58,7 +58,7 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R // GenerateStream generates a streaming response from OpenAI using the Responses API. func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GenerateParams) (<-chan StreamEvent, error) { - if HasUnsupportedResponsesPromptContext(params.Context) { + if hasUnsupportedResponsesPromptContext(params.Context) { return nil, fmt.Errorf("responses API does not support prompt context block types required by this request") } @@ -148,7 +148,7 @@ func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GeneratePara // Generate performs a non-streaming generation using the Responses API. func (o *OpenAIProvider) Generate(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - if HasUnsupportedResponsesPromptContext(params.Context) { + if hasUnsupportedResponsesPromptContext(params.Context) { return nil, fmt.Errorf("responses API does not support prompt context block types required by this request") } diff --git a/bridges/ai/provisioning.go b/bridges/ai/provisioning.go index 6d3e5caf..9d458d0e 100644 --- a/bridges/ai/provisioning.go +++ b/bridges/ai/provisioning.go @@ -541,11 +541,15 @@ func resolveNamedMCPServer(client *AIClient, name string) (namedMCPServer, error } func ensureLoginMCPServer(meta *UserLoginMetadata) { - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} + creds := ensureLoginCredentials(meta) + if creds == nil { + return + } + if creds.ServiceTokens == nil { + creds.ServiceTokens = &ServiceTokens{} } - if meta.ServiceTokens.MCPServers == nil { - meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} + if creds.ServiceTokens.MCPServers == nil { + creds.ServiceTokens.MCPServers = map[string]MCPServerConfig{} } } @@ -584,7 +588,12 @@ func (api *ProvisioningAPI) handleCreateMCPServer(w http.ResponseWriter, r *http } meta := loginMetadata(login) ensureLoginMCPServer(meta) - if _, exists := meta.ServiceTokens.MCPServers[name]; exists { + tokens := loginCredentialServiceTokens(meta) + if tokens == nil { + mautrix.MUnknown.WithMessage("Couldn't load MCP servers for this login.").Write(w) + return + } + if _, exists := tokens.MCPServers[name]; exists { mautrix.MInvalidParam.WithMessage("MCP server %s already exists.", name).Write(w) return } diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index 305bf8a4..f742c99d 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -340,7 +340,7 @@ func (oc *AIClient) runCompactionFlushHook( Client: oc, Portal: portal, Meta: meta, - Prompt: PromptContextToChatCompletionMessages(prompt, false), + Prompt: promptContextToChatCompletionMessages(prompt, false), RequestedTokens: cle.RequestedTokens, ModelMaxTokens: cle.ModelMaxTokens, Attempt: attempt, @@ -366,7 +366,7 @@ func (oc *AIClient) runAgentLoopWithRetry( } func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext PromptContext) (responseFuncCanonical, string) { - if HasUnsupportedResponsesPromptContext(promptContext) { + if hasUnsupportedResponsesPromptContext(promptContext) { return oc.runChatCompletionsAgentLoopPrompt, "chat_completions" } modelID := "" @@ -417,7 +417,7 @@ func (oc *AIClient) runtimeCompactOnOverflow( requestedTokens int, currentPromptTokens int, ) (PromptContext, airuntime.CompactionDecision, bool) { - serialized := PromptContextToChatCompletionMessages(prompt, false) + serialized := promptContextToChatCompletionMessages(prompt, false) result := airuntime.CompactPromptOnOverflow(airuntime.OverflowCompactionInput{ Prompt: serialized, ContextWindowTokens: contextWindowTokens, @@ -432,7 +432,7 @@ func (oc *AIClient) runtimeCompactOnOverflow( MaxHistoryShare: oc.pruningMaxHistoryShare(), ProtectedTail: 3, }) - return ChatMessagesToPromptContext(result.Prompt), result.Decision, result.Success + return chatMessagesToPromptContext(result.Prompt), result.Decision, result.Success } func (oc *AIClient) truncateOversizedToolResultsForOverflow( diff --git a/bridges/ai/runtime_compaction_adapter.go b/bridges/ai/runtime_compaction_adapter.go index d2155b8e..165cb0ac 100644 --- a/bridges/ai/runtime_compaction_adapter.go +++ b/bridges/ai/runtime_compaction_adapter.go @@ -165,5 +165,5 @@ func estimatePromptTokensForModel(prompt []openai.ChatCompletionMessageParamUnio } func estimatePromptContextTokensForModel(prompt PromptContext, model string) int { - return estimatePromptTokensForModel(PromptContextToChatCompletionMessages(prompt, false), model) + return estimatePromptTokensForModel(promptContextToChatCompletionMessages(prompt, false), model) } diff --git a/bridges/ai/streaming_chat_completions.go b/bridges/ai/streaming_chat_completions.go index 4d38105f..4f1c3ea6 100644 --- a/bridges/ai/streaming_chat_completions.go +++ b/bridges/ai/streaming_chat_completions.go @@ -31,7 +31,7 @@ func (a *chatCompletionsTurnAdapter) handleStreamStepError( if cle := ParseContextLengthError(stepErr); cle != nil { return cle, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, a.state, a.meta, "context-length", stepErr) } - logChatCompletionsFailure(a.log, stepErr, params, a.meta, currentMessages, "stream_err") + logChatCompletionsFailure(a.log, stepErr, params, a.meta, a.prompt, "stream_err") return nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, a.state, a.meta, "error", stepErr) } @@ -48,14 +48,14 @@ func (a *chatCompletionsTurnAdapter) RunAgentTurn( typingSignals := a.typingSignals touchTyping := a.touchTyping isHeartbeat := a.isHeartbeat - currentMessages := PromptContextToChatCompletionMessages(a.prompt, oc.isOpenRouterProvider()) + currentMessages := promptContextToChatCompletionMessages(a.prompt, oc.isOpenRouterProvider()) params := oc.buildChatCompletionsAgentLoopParams(ctx, meta, currentMessages) stream := oc.api.Chat.Completions.NewStreaming(ctx, params) if stream == nil { initErr := errors.New("chat completions streaming not available") - logChatCompletionsFailure(log, initErr, params, meta, currentMessages, "stream_init") + logChatCompletionsFailure(log, initErr, params, meta, a.prompt, "stream_init") return false, nil, oc.finishStreamingWithFailure(ctx, log, portal, state, meta, "error", initErr) } diff --git a/bridges/ai/streaming_continuation.go b/bridges/ai/streaming_continuation.go index 93540f84..24ad4ffa 100644 --- a/bridges/ai/streaming_continuation.go +++ b/bridges/ai/streaming_continuation.go @@ -11,17 +11,18 @@ import ( // and/or after responding to tool approval requests. func (oc *AIClient) buildContinuationParams( ctx context.Context, + prompt *PromptContext, state *streamingState, meta *PortalMetadata, pendingOutputs []functionCallOutput, approvalInputs []responses.ResponseInputItemUnionParam, ) responses.ResponseNewParams { - // Build function call outputs as input - var input responses.ResponseInputParam - if len(state.baseInput) > 0 { - // All Responses continuations are stateless: include the accumulated local history. - input = append(input, state.baseInput...) + currentPrompt := PromptContext{} + if prompt != nil { + currentPrompt = ClonePromptContext(*prompt) } + var input responses.ResponseInputParam + input = append(input, promptContextToResponsesInput(currentPrompt)...) input = append(input, approvalInputs...) for _, output := range pendingOutputs { if output.name != "" { @@ -38,13 +39,20 @@ func (oc *AIClient) buildContinuationParams( steerPrompts = oc.getSteeringMessages(state.roomID) } if len(steerPrompts) > 0 { + steeringMessages := buildSteeringPromptMessages(steerPrompts) + if prompt != nil && len(steeringMessages) > 0 { + prompt.Messages = append(prompt.Messages, steeringMessages...) + } steerInput := oc.buildSteeringInputItems(steerPrompts, meta) if len(steerInput) > 0 { input = append(input, steerInput...) - state.baseInput = append(state.baseInput, steerInput...) } } - return oc.buildResponsesAgentLoopParams(ctx, meta, state.baseSystemPrompt, input, true) + systemPrompt := "" + if prompt != nil { + systemPrompt = prompt.SystemPrompt + } + return oc.buildResponsesAgentLoopParams(ctx, meta, systemPrompt, input, true) } func (oc *AIClient) buildSteeringInputItems(prompts []string, meta *PortalMetadata) responses.ResponseInputParam { @@ -57,7 +65,7 @@ func (oc *AIClient) buildSteeringInputItems(prompts []string, meta *PortalMetada if prompt == "" { continue } - input = append(input, PromptContextToResponsesInput(UserPromptContext( + input = append(input, promptContextToResponsesInput(UserPromptContext( PromptBlock{Type: PromptBlockText, Text: prompt}, ))...) } diff --git a/bridges/ai/streaming_executor.go b/bridges/ai/streaming_executor.go index 9d20ba86..22aaa101 100644 --- a/bridges/ai/streaming_executor.go +++ b/bridges/ai/streaming_executor.go @@ -74,8 +74,7 @@ func (oc *AIClient) runAgentLoop( prompt PromptContext, newProvider func(prep streamingRunPrep, prompt PromptContext) agentLoopProvider, ) (bool, *ContextLengthError, error) { - messages := PromptContextToChatCompletionMessages(prompt, oc.isOpenRouterProvider()) - prep, _, typingCleanup := oc.prepareStreamingRun(ctx, log, evt, portal, meta, messages) + prep, typingCleanup := oc.prepareStreamingRun(ctx, log, evt, portal, meta) defer typingCleanup() state := prep.State diff --git a/bridges/ai/streaming_init.go b/bridges/ai/streaming_init.go index dc66cf84..28862cb8 100644 --- a/bridges/ai/streaming_init.go +++ b/bridges/ai/streaming_init.go @@ -3,7 +3,6 @@ package ai import ( "context" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" @@ -103,8 +102,7 @@ func (oc *AIClient) prepareStreamingRun( evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, -) (prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion, cleanup func()) { +) (prep streamingRunPrep, cleanup func()) { var sourceEventID id.EventID senderID := "" if evt != nil { @@ -176,13 +174,11 @@ func (oc *AIClient) prepareStreamingRun( } } - pruned = messages - prep = streamingRunPrep{ State: state, TypingSignals: typingSignals, TouchTyping: touchTyping, IsHeartbeat: isHeartbeat, } - return prep, pruned, cleanup + return prep, cleanup } diff --git a/bridges/ai/streaming_init_test.go b/bridges/ai/streaming_init_test.go index dac03c14..34cb5d67 100644 --- a/bridges/ai/streaming_init_test.go +++ b/bridges/ai/streaming_init_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" @@ -29,13 +28,12 @@ func TestPrepareStreamingRun_ModelRoomKeepsReplyTarget(t *testing.T) { }, } - prep, _, cleanup := oc.prepareStreamingRun( + prep, cleanup := oc.prepareStreamingRun( context.Background(), zerolog.Nop(), evt, nil, meta, - []openai.ChatCompletionMessageParamUnion{}, ) defer cleanup() @@ -64,13 +62,12 @@ func TestPrepareStreamingRun_AgentRoomKeepsReplyTarget(t *testing.T) { }, } - prep, _, cleanup := oc.prepareStreamingRun( + prep, cleanup := oc.prepareStreamingRun( context.Background(), zerolog.Nop(), evt, nil, meta, - []openai.ChatCompletionMessageParamUnion{}, ) defer cleanup() @@ -94,13 +91,12 @@ func TestPrepareStreamingRun_SnapshotsResponderFields(t *testing.T) { } meta := modelModeTestMeta("openai/gpt-5.2") - prep, _, cleanup := oc.prepareStreamingRun( + prep, cleanup := oc.prepareStreamingRun( context.Background(), zerolog.Nop(), nil, nil, meta, - []openai.ChatCompletionMessageParamUnion{}, ) defer cleanup() diff --git a/bridges/ai/streaming_input_conversion.go b/bridges/ai/streaming_input_conversion.go index 010056e9..72536264 100644 --- a/bridges/ai/streaming_input_conversion.go +++ b/bridges/ai/streaming_input_conversion.go @@ -1,31 +1,15 @@ package ai -import ( - "github.com/openai/openai-go/v3" -) - -// hasAudioContent checks if the prompt contains audio content -func hasAudioContent(messages []openai.ChatCompletionMessageParamUnion) bool { - for _, msg := range messages { - if msg.OfUser != nil && len(msg.OfUser.Content.OfArrayOfContentParts) > 0 { - for _, part := range msg.OfUser.Content.OfArrayOfContentParts { - if part.OfInputAudio != nil { - return true - } - } - } - } +func promptHasAudioContent(prompt PromptContext) bool { + _ = prompt return false } -// hasMultimodalContent checks if the prompt contains non-text content (image, file, audio). -func hasMultimodalContent(messages []openai.ChatCompletionMessageParamUnion) bool { - for _, msg := range messages { - if msg.OfUser != nil && len(msg.OfUser.Content.OfArrayOfContentParts) > 0 { - for _, part := range msg.OfUser.Content.OfArrayOfContentParts { - if part.OfImageURL != nil || part.OfFile != nil || part.OfInputAudio != nil { - return true - } +func promptHasMultimodalContent(prompt PromptContext) bool { + for _, msg := range prompt.Messages { + for _, block := range msg.Blocks { + if block.Type == PromptBlockImage { + return true } } } diff --git a/bridges/ai/streaming_responses_api.go b/bridges/ai/streaming_responses_api.go index 0dd014ff..ab5e25bb 100644 --- a/bridges/ai/streaming_responses_api.go +++ b/bridges/ai/streaming_responses_api.go @@ -37,7 +37,7 @@ func (a *responsesTurnAdapter) TrackRoomRunStreaming() bool { func (a *responsesTurnAdapter) startInitialRound(ctx context.Context) (*ssestream.Stream[responses.ResponseStreamEventUnion], error) { if !a.initialized { - input := PromptContextToResponsesInput(a.prompt) + input := promptContextToResponsesInput(a.prompt) a.params = a.oc.buildResponsesAgentLoopParams(ctx, a.meta, a.prompt.SystemPrompt, input, false) if len(a.params.Tools) > 0 { zerolog.Ctx(ctx).Debug().Int("count", len(a.params.Tools)).Msg("Added streaming turn tools") @@ -45,16 +45,12 @@ func (a *responsesTurnAdapter) startInitialRound(ctx context.Context) (*ssestrea if a.oc.isOpenRouterProvider() { ctx = WithPDFEngine(ctx, a.oc.effectivePDFEngine(a.meta)) } - a.state.baseSystemPrompt = a.prompt.SystemPrompt a.initialized = true } stream := a.oc.api.Responses.NewStreaming(ctx, a.params) if stream == nil { return nil, errors.New("responses streaming not available") } - if a.params.Input.OfInputItemList != nil { - a.state.baseInput = a.params.Input.OfInputItemList - } return stream, nil } @@ -89,10 +85,7 @@ func (a *responsesTurnAdapter) startContinuationRound(ctx context.Context) (*sse approvalInputs = append(approvalInputs, item) } - continuationParams := a.oc.buildContinuationParams(ctx, state, a.meta, pendingOutputs, approvalInputs) - if continuationInput := continuationParams.Input.OfInputItemList; continuationInput != nil { - state.baseInput = slices.Clone(continuationInput) - } + continuationParams := a.oc.buildContinuationParams(ctx, &a.prompt, state, a.meta, pendingOutputs, approvalInputs) state.needsTextSeparator = true stream := a.oc.api.Responses.NewStreaming(ctx, continuationParams) @@ -120,7 +113,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( stream, err = a.startInitialRound(ctx) params = a.params if err != nil { - logResponsesFailure(a.log, err, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), "stream_init") + logResponsesFailure(a.log, err, params, a.meta, a.prompt, "stream_init") return false, nil, &PreDeltaError{Err: err} } } else { @@ -135,14 +128,14 @@ func (a *responsesTurnAdapter) RunAgentTurn( a.log.Debug(). Int("pending_outputs", len(state.pendingFunctionOutputs)). Int("pending_approvals", len(state.pendingMcpApprovals)). - Int("base_input_items", len(state.baseInput)). + Int("prompt_messages", len(a.prompt.Messages)). Msg("Continuing stateless response with pending tool actions") stream, params, err = a.startContinuationRound(ctx) if err != nil { if errors.Is(err, context.Canceled) { return false, nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, state, a.meta, "cancelled", err) } - logResponsesFailure(a.log, err, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), "continuation_init") + logResponsesFailure(a.log, err, params, a.meta, a.prompt, "continuation_init") return false, nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, state, a.meta, "error", err) } } @@ -158,7 +151,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( if round > 0 { stage = "continuation_event_error" } - logResponsesFailure(a.log, evtErr, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), stage) + logResponsesFailure(a.log, evtErr, params, a.meta, a.prompt, stage) } return done, cle, evtErr }, @@ -167,7 +160,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( if round > 0 { stage = "continuation_err" } - logResponsesFailure(a.log, stepErr, params, a.meta, PromptContextToChatCompletionMessages(a.prompt, a.oc.isOpenRouterProvider()), stage) + logResponsesFailure(a.log, stepErr, params, a.meta, a.prompt, stage) return a.oc.handleResponsesStreamErr(ctx, a.portal, state, a.meta, stepErr, round == 0) }, ) @@ -190,7 +183,6 @@ func (a *responsesTurnAdapter) ContinueAgentLoop(messages []PromptMessage) { return } a.prompt.Messages = append(a.prompt.Messages, messages...) - a.state.baseInput = append(a.state.baseInput, PromptContextToResponsesInput(PromptContext{Messages: messages})...) a.hasFollowUp = true } diff --git a/bridges/ai/streaming_responses_input_test.go b/bridges/ai/streaming_responses_input_test.go index 6e950286..7a630800 100644 --- a/bridges/ai/streaming_responses_input_test.go +++ b/bridges/ai/streaming_responses_input_test.go @@ -14,7 +14,7 @@ func TestConvertToResponsesInput_RolesAndToolOutput(t *testing.T) { openai.ToolMessage("tool output", "call_123"), } - input := PromptContextToResponsesInput(ChatMessagesToPromptContext(messages)) + input := promptContextToResponsesInput(chatMessagesToPromptContext(messages)) if len(input) != 2 { t.Fatalf("expected 2 input items, got %d", len(input)) } @@ -53,7 +53,7 @@ func TestConvertToResponsesInput_AssistantToolCalls(t *testing.T) { }, }} - input := PromptContextToResponsesInput(ChatMessagesToPromptContext(messages)) + input := promptContextToResponsesInput(chatMessagesToPromptContext(messages)) if len(input) != 2 { t.Fatalf("expected 2 input items, got %d", len(input)) } diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index ac688887..18990d95 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -38,8 +38,6 @@ type streamingState struct { reasoningTokens int64 totalTokens int64 - baseSystemPrompt string - baseInput responses.ResponseInputParam accumulated strings.Builder reasoning strings.Builder toolCalls []ToolCallMetadata diff --git a/bridges/ai/token_resolver.go b/bridges/ai/token_resolver.go index abacc936..7addbea0 100644 --- a/bridges/ai/token_resolver.go +++ b/bridges/ai/token_resolver.go @@ -108,10 +108,8 @@ func (oc *OpenAIConnector) resolveProxyRoot(meta *UserLoginMetadata) string { if oc == nil { return "" } - if meta != nil { - if raw := strings.TrimSpace(meta.BaseURL); raw != "" { - return normalizeProxyBaseURL(raw) - } + if raw := loginCredentialBaseURL(meta); raw != "" { + return normalizeProxyBaseURL(raw) } return "" } @@ -147,9 +145,9 @@ func (oc *OpenAIConnector) resolveServiceConfig(meta *UserLoginMetadata) Service } if meta.Provider == ProviderMagicProxy { - base := normalizeProxyBaseURL(meta.BaseURL) + base := normalizeProxyBaseURL(loginCredentialBaseURL(meta)) if base != "" { - token := trimToken(meta.APIKey) + token := trimToken(loginCredentialAPIKey(meta)) services[serviceOpenRouter] = ServiceConfig{ BaseURL: joinProxyPath(base, "/openrouter/v1"), APIKey: token, @@ -190,34 +188,34 @@ func (oc *OpenAIConnector) resolveProviderAPIKey(meta *UserLoginMetadata) string } switch meta.Provider { case ProviderMagicProxy: - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } case ProviderOpenRouter: if key := trimToken(oc.modelProviderConfig(ProviderOpenRouter).APIKey); key != "" { return key } - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } case ProviderOpenAI: if key := trimToken(oc.modelProviderConfig(ProviderOpenAI).APIKey); key != "" { return key } - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenAI) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenAI) } default: - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } return "" } @@ -230,12 +228,12 @@ func (oc *OpenAIConnector) resolveOpenAIAPIKey(meta *UserLoginMetadata) string { return "" } if meta.Provider == ProviderOpenAI { - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenAI) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenAI) } return "" } @@ -248,15 +246,15 @@ func (oc *OpenAIConnector) resolveOpenRouterAPIKey(meta *UserLoginMetadata) stri return "" } if meta.Provider == ProviderOpenRouter { - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } } if meta.Provider == ProviderMagicProxy { - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } return "" } @@ -268,21 +266,21 @@ func loginTokenForService(meta *UserLoginMetadata, service string) string { switch service { case serviceOpenAI: if meta.Provider == ProviderOpenAI { - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenAI) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenAI) } case serviceOpenRouter: if meta.Provider == ProviderOpenRouter || meta.Provider == ProviderMagicProxy { - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } case serviceExa: - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.Exa) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.Exa) } } return "" diff --git a/bridges/ai/tools.go b/bridges/ai/tools.go index 84a59f3a..401e845e 100644 --- a/bridges/ai/tools.go +++ b/bridges/ai/tools.go @@ -1361,7 +1361,7 @@ func resolveOpenAITTSBaseURL(btc *BridgeToolContext, providerBaseURL string) (st } } } - if root := normalizeProxyBaseURL(meta.BaseURL); root != "" { + if root := normalizeProxyBaseURL(loginCredentialBaseURL(meta)); root != "" { return joinProxyPath(root, "/openai/v1"), true } diff --git a/bridges/ai/tools_search_fetch.go b/bridges/ai/tools_search_fetch.go index e610e884..25b8548d 100644 --- a/bridges/ai/tools_search_fetch.go +++ b/bridges/ai/tools_search_fetch.go @@ -213,7 +213,7 @@ func applyExaProxyDefaultsTo(baseURL *string, apiKey *string, meta *UserLoginMet } if *apiKey == "" { if meta != nil && meta.Provider == ProviderMagicProxy { - if token := strings.TrimSpace(meta.APIKey); token != "" { + if token := loginCredentialAPIKey(meta); token != "" { *apiKey = token } } diff --git a/bridges/ai/tools_search_fetch_test.go b/bridges/ai/tools_search_fetch_test.go index 7fcea4bd..a8006da4 100644 --- a/bridges/ai/tools_search_fetch_test.go +++ b/bridges/ai/tools_search_fetch_test.go @@ -10,8 +10,10 @@ func TestApplyLoginTokensToSearchConfig_MagicProxyForcesExa(t *testing.T) { oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "magic-token", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "magic-token", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } cfg := &search.Config{ Provider: search.ProviderExa, @@ -60,7 +62,9 @@ func TestApplyLoginTokensToSearchConfig_DefaultExaEndpointDoesNotForceExa(t *tes oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderOpenRouter, - APIKey: "openrouter-token", + Credentials: &LoginCredentials{ + APIKey: "openrouter-token", + }, } cfg := &search.Config{ Provider: search.ProviderExa, diff --git a/bridges/ai/tools_tts_test.go b/bridges/ai/tools_tts_test.go index 9def44c5..14ef62ad 100644 --- a/bridges/ai/tools_tts_test.go +++ b/bridges/ai/tools_tts_test.go @@ -25,7 +25,9 @@ func newTTSTestBridgeContext(meta *UserLoginMetadata, oc *OpenAIConnector) *Brid func TestResolveOpenAITTSBaseURLMagicProxy(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -42,7 +44,9 @@ func TestResolveOpenAITTSBaseURLMagicProxy(t *testing.T) { func TestResolveOpenAITTSBaseURLMagicProxyWithoutConnector(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + Credentials: &LoginCredentials{ + BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + }, } btc := newTTSTestBridgeContext(meta, nil) diff --git a/pkg/connector/integrations_example-config.yaml b/pkg/connector/integrations_example-config.yaml deleted file mode 100644 index be28341e..00000000 --- a/pkg/connector/integrations_example-config.yaml +++ /dev/null @@ -1,364 +0,0 @@ -# Connector-specific configuration lives under the `network:` section of the -# main config file. - -# Beeper Cloud credentials for automatic login (optional). -# If user_mxid, base_url, and token are set, users don't need to manually log in. -beeper: - user_mxid: "" # Owning Matrix user for the built-in Beeper Cloud login. - base_url: "" # Optional. If empty, login uses selected Beeper domain. - token: "" # Beeper Matrix access token - -# Per-provider default models and settings. -# These are used when a room doesn't have a specific model configured. -providers: - beeper: - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - openai: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://api.openai.com/v1 - base_url: "https://api.openai.com/v1" - default_model: "openai/gpt-5.2" - openrouter: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://openrouter.ai/api/v1 - base_url: "https://openrouter.ai/api/v1" - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - -# Optional model catalog seeding. -# models: -# mode: "merge" # merge | replace -# providers: -# openai: -# models: -# - id: "gpt-5.2" -# name: "GPT-5.2" -# reasoning: true -# input: ["text", "image"] -# context_window: 128000 -# max_tokens: 8192 - -# Global settings -default_system_prompt: | - You are a helpful, concise assistant. - Ask clarifying questions when needed. - Follow the user's intent and be accurate. -model_cache_duration: 6h - -# Optional message rendering settings. -messages: - # History defaults for prompt construction. - # Set 0 to disable. - direct_chat: - history_limit: 20 - group_chat: - history_limit: 50 - # Queue behavior while the agent is busy. - queue: - # Modes: collect, followup, steer, steer-backlog, interrupt - mode: "collect" - # Debounce time before draining queued messages (ms). - debounce_ms: 1000 - # Maximum queued messages before drop policy applies. - cap: 20 - # Drop policy when cap is exceeded: summarize, old, new - drop: "summarize" - -# Command authorization settings. -commands: - # Optional allowlist for owner-only tools/commands (Matrix IDs, or "matrix:@user:server"). - owner_allow_from: [] - -# Tool approval gating. -tool_approvals: - enabled: true - ttl_seconds: 600 - require_for_mcp: true - # List of builtin tool names that require approval (subject to per-tool action allowlists). - # Note: `message` approvals apply to Desktop API routing too (e.g. action=send/reply/edit with desktop chat hints), - # while Desktop read-only actions like desktop-search-* do not require approval. - require_for_tools: ["message", "cron", "gravatar_set", "create_agent", "fork_agent", "edit_agent", "delete_agent", "modify_room", "sessions_send", "sessions_spawn", "run_internal_command"] - # Fallback when approval times out: "deny" (default) | "allow". - # Set to "allow" for cron/automated contexts where no human can respond. - -# Optional per-channel overrides. -channels: - matrix: - # Matrix reply/thread behavior. - reply_to_mode: "first" - -# Session configuration. -session: - # Scope for session state: per-sender (default) or global. - scope: "per-sender" - # Main session key alias (default: "main"). - main_key: "main" - -# External tool providers (search + fetch). Proxy is optional. -tools: - search: - provider: "openrouter" - fallbacks: ["exa", "brave", "perplexity"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - type: "auto" - num_results: 5 - include_text: false - text_max_chars: 500 - highlights: true # enabled by default; provides description snippets for source cards - brave: - api_key: "" - base_url: "https://api.search.brave.com/res/v1/web/search" - perplexity: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "perplexity/sonar-pro" - openrouter: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "openai/gpt-5.2" - fetch: - provider: "exa" - fallbacks: ["direct"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - include_text: true - text_max_chars: 5000 - direct: - enabled: true - timeout_seconds: 30 - max_chars: 50000 - max_redirects: 3 - - # Generic MCP behavior. - mcp: - # Disabled by default for safety. Enable explicitly to allow local stdio MCP servers. - enable_stdio: false - - # Virtual filesystem tools. - vfs: - apply_patch: - enabled: false - allow_models: [] - - # Media understanding/transcription. - # Supports provider/CLI entries and per-capability defaults. - media: - concurrency: 2 - image: - enabled: true - prompt: "Describe the image." - max_bytes: 10485760 - max_chars: 500 - timeout_seconds: 60 - models: - - provider: "openrouter" - model: "google/gemini-3-flash-preview" - audio: - enabled: true - prompt: "Transcribe the audio." - language: "" - # CLI transcription auto-detection (whisper/whisper.cpp) is not implemented yet. - max_bytes: 20971520 - timeout_seconds: 60 - models: - - provider: "openai" - model: "gpt-4o-mini-transcribe" - video: - enabled: true - prompt: "Describe the video." - max_bytes: 52428800 - timeout_seconds: 120 - models: - - provider: "openrouter" - model: "google/gemini-3-flash-preview" - chunking: - tokens: 400 - overlap: 80 - sync: - on_session_start: true - on_search: true - watch: true - watch_debounce_ms: 1500 - interval_minutes: 0 - sessions: - delta_bytes: 100000 - delta_messages: 50 - query: - max_results: 6 - min_score: 0.35 - hybrid: - candidate_multiplier: 4 - cache: - enabled: true - max_entries: -1 # Unlimited when cache.enabled is true; experimental.session_memory does not change cache size semantics. - experimental: - session_memory: false - -# Recall configuration. -# recall: -# citations: "auto" # auto | on | off -# inject_context: false # default false. when true, injects MEMORY.md snippets as extra system context. - - # Tool policy. Controls allow/deny lists and profiles. - # tool_policy: - # profile: "full" - # # group:openclaw is the strict OpenClaw native tool set. - # # group:agentremote includes agentremote-only extras (beeper_docs, gravatar_*, tts, image_generate, calculator, etc). - # allow: ["group:openclaw", "group:agentremote"] - # deny: [] - # subagents: - # tools: - # deny: ["sessions_list", "sessions_history", "sessions_send"] - - # Agent defaults. - # agents: - # defaults: - # subagents: - # model: "anthropic/claude-sonnet-4.5" - # allow_agents: ["*"] - # skip_bootstrap: false - # bootstrap_max_chars: 20000 - # typing_mode: "instant" # never|instant|thinking|message (message ignores NO_REPLY) - # typing_interval_seconds: 6 # refresh cadence, not start time (heartbeats never show typing) - # soul_evil: - # file: "SOUL_EVIL.md" - # chance: 0.1 - # purge: - # at: "21:00" - # duration: "15m" - -# Context pruning configuration. -# Reduces token usage by intelligently truncating old tool results. -pruning: - # Pruning mode: off | cache-ttl - # cache-ttl is the default pruning mode. - mode: "cache-ttl" - - # Refresh interval for cache-ttl mode. - ttl: "1h" - - # Enable proactive context pruning - enabled: true - - # Ratio of context window usage that triggers soft trimming (0.0-1.0) - # At 30% usage, large tool results start getting truncated - soft_trim_ratio: 0.3 - - # Ratio of context window usage that triggers hard clearing (0.0-1.0) - # At 50% usage, old tool results are replaced with placeholder - hard_clear_ratio: 0.5 - - # Number of recent assistant messages to protect from pruning - keep_last_assistants: 3 - - # Minimum total chars in prunable tool results before hard clear kicks in - min_prunable_chars: 50000 - - # Tool results larger than this are candidates for soft trimming - soft_trim_max_chars: 4000 - - # When soft trimming, keep this many chars from the start - soft_trim_head_chars: 1500 - - # When soft trimming, keep this many chars from the end - soft_trim_tail_chars: 1500 - - # Enable/disable hard clear phase - hard_clear_enabled: true - - # Placeholder text for hard-cleared tool results - hard_clear_placeholder: "[Old tool result content cleared]" - - # Tool patterns to allow/deny pruning (supports wildcards: list_*, *_search) - # Empty means all tools are prunable unless denied - # tools_allow: [] - # tools_deny: [] - - # --- LLM-based summarization (compaction) --- - # When enabled, uses an LLM to generate intelligent summaries of compacted - # content instead of just using placeholder text. This preserves context better. - - # Enable LLM summarization (default: true when pruning is enabled) - summarization_enabled: true - - # Model to use for generating summaries (default: fast model) - summarization_model: "openai/gpt-5.2" - - # Maximum tokens for generated summaries - max_summary_tokens: 500 - - # Compaction mode: - # - default: balanced reduction - # - safeguard: preserves recent context more aggressively - compaction_mode: "safeguard" - - # Minimum recent token budget preserved during safeguard compaction - keep_recent_tokens: 20000 - - # Maximum ratio of context that history can consume (0.0-1.0) - # When exceeded, oldest messages are summarized to fit budget - max_history_share: 0.5 - - # Token budget reserved for compaction output - reserve_tokens: 20000 - # Floor applied to reserve_tokens to avoid aggressive overfill - reserve_tokens_floor: 20000 - - # Optional post-compaction system context injected before retry - post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." - - # Additional instructions for the summarization model - # custom_instructions: "Focus on preserving code decisions and TODOs" - - # Identifier preservation policy for summaries: - # - strict (default): preserve opaque identifiers exactly - # - off: no special identifier-preservation instruction - # - custom: use identifier_instructions below - identifier_policy: "strict" - # identifier_instructions: "Keep ticket IDs, hashes, and hostnames unchanged." - - # Optional pre-compaction overflow flush turn. - # Enabled by default. Disable explicitly if you want no pre-flush. - overflow_flush: - enabled: true - soft_threshold_tokens: 4000 - prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." - system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - -# Link preview configuration. -# Automatically fetches metadata for URLs in messages to provide context to the AI -# and generate rich previews in outgoing AI responses. -link_previews: - # Enable link preview functionality (default: true) - enabled: true - - # Maximum number of URLs to fetch from user messages for AI context (default: 3) - max_urls_inbound: 3 - - # Maximum number of URLs to preview in AI responses (default: 5) - max_urls_outbound: 5 - - # Timeout for fetching each URL (default: 10s) - fetch_timeout: 10s - - # Maximum characters from description to include in context (default: 500) - max_content_chars: 500 - - # Maximum page size to download in bytes (default: 10MB) - max_page_bytes: 10485760 - - # Maximum image size to download in bytes (default: 5MB) - max_image_bytes: 5242880 - - # How long to cache URL previews (default: 1h) - cache_ttl: 1h From cb875a763f37af44d6ecc251fe9ac7faaded93a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:37:40 +0200 Subject: [PATCH 09/23] Refactor AI bridge helpers and media calls Consolidate and clean up duplicated logic across the AI bridge. - Add emitHeartbeatFailure to centralize heartbeat failure event emission and replace repeated inlined blocks. - Introduce callGeminiMediaCapability to unify Google/Gemini media capability calls for image/video and remove duplicated request construction; remove media request type aliases. - Replace mergeLoginCredentials usage by cloning service tokens and setting credentials nil when empty to avoid in-place merges (uses cloneServiceTokens). - Remove legacy config upgrade machinery and deprecated mergeServiceTokens/mergeLoginCredentials functions and related configupgrade imports. These changes reduce duplication, simplify imports, and make media and heartbeat handling more maintainable. --- bridges/ai/agent_loop_request_builders.go | 4 +- bridges/ai/client.go | 44 ++--- bridges/ai/constructors.go | 14 +- bridges/ai/error_logging.go | 1 - bridges/ai/handlematrix.go | 8 +- bridges/ai/heartbeat_execute.go | 58 ++---- bridges/ai/integrations_config.go | 200 -------------------- bridges/ai/login.go | 8 +- bridges/ai/media_understanding_providers.go | 3 - bridges/ai/media_understanding_runner.go | 51 +++-- bridges/ai/messages.go | 15 +- bridges/ai/metadata.go | 80 -------- bridges/ai/pending_queue.go | 8 +- bridges/ai/prompt_builder.go | 28 ++- bridges/ai/prompt_context_ops.go | 8 +- bridges/ai/session_greeting.go | 14 -- bridges/ai/session_greeting_test.go | 22 +-- bridges/ai/streaming_continuation.go | 6 +- bridges/ai/streaming_input_conversion.go | 5 - 19 files changed, 106 insertions(+), 471 deletions(-) diff --git a/bridges/ai/agent_loop_request_builders.go b/bridges/ai/agent_loop_request_builders.go index a6e2be3a..3ad189b7 100644 --- a/bridges/ai/agent_loop_request_builders.go +++ b/bridges/ai/agent_loop_request_builders.go @@ -119,8 +119,8 @@ func (oc *AIClient) buildResponsesAgentLoopParams( if settings.temperature != nil { params.Temperature = openai.Float(*settings.temperature) } - if strings.TrimSpace(systemPrompt) != "" { - params.Instructions = openai.String(strings.TrimSpace(systemPrompt)) + if trimmed := strings.TrimSpace(systemPrompt); trimmed != "" { + params.Instructions = openai.String(trimmed) } if effort, ok := reasoningEffortMap[settings.reasoningEffort]; ok { params.Reasoning = shared.ReasoningParam{ diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 4092881c..251536f7 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -1155,15 +1155,9 @@ func (oc *AIClient) defaultModelForProvider() string { } switch loginMeta.Provider { case ProviderOpenAI: - if configured := strings.TrimSpace(oc.defaultModelSelection(ProviderOpenAI).Primary); configured != "" { - return configured - } - return DefaultModelOpenAI + return oc.defaultModelSelection(ProviderOpenAI).Primary case ProviderOpenRouter, ProviderMagicProxy: - if configured := strings.TrimSpace(oc.defaultModelSelection(ProviderOpenRouter).Primary); configured != "" { - return configured - } - return DefaultModelOpenRouter + return oc.defaultModelSelection(ProviderOpenRouter).Primary default: return DefaultModelOpenRouter } @@ -1171,19 +1165,24 @@ func (oc *AIClient) defaultModelForProvider() string { func (oc *AIClient) defaultModelSelection(provider string) ModelSelectionConfig { if oc == nil || oc.connector == nil || oc.connector.Config.Agents == nil || oc.connector.Config.Agents.Defaults == nil || oc.connector.Config.Agents.Defaults.Model == nil { - return ModelSelectionConfig{} + return ModelSelectionConfig{Primary: defaultModelForProviderName(provider)} } selection := *oc.connector.Config.Agents.Defaults.Model - if strings.TrimSpace(selection.Primary) != "" { - return selection + if strings.TrimSpace(selection.Primary) == "" { + selection.Primary = defaultModelForProviderName(provider) } + return selection +} + +func defaultModelForProviderName(provider string) string { switch strings.ToLower(strings.TrimSpace(provider)) { case ProviderOpenAI: - selection.Primary = DefaultModelOpenAI + return DefaultModelOpenAI case ProviderOpenRouter, ProviderMagicProxy: - selection.Primary = DefaultModelOpenRouter + return DefaultModelOpenRouter + default: + return DefaultModelOpenRouter } - return selection } // effectivePrompt returns the base system prompt to use for non-agent rooms. @@ -1511,11 +1510,8 @@ func (oc *AIClient) isGroupChat(ctx context.Context, portal *bridgev2.Portal) bo } func (oc *AIClient) defaultPDFEngine() string { - if oc != nil && oc.connector != nil && oc.connector.Config.Agents != nil && - oc.connector.Config.Agents.Defaults != nil { - if engine := strings.TrimSpace(oc.connector.Config.Agents.Defaults.PDFEngine); engine != "" { - return engine - } + if oc != nil && oc.connector != nil { + return oc.connector.defaultPDFEngineForInit() } return "mistral-ocr" } @@ -1841,7 +1837,7 @@ func (oc *AIClient) buildMediaTurnContext( eventID id.EventID, ) (PromptContext, error) { return oc.buildPromptContextForTurn(ctx, portal, meta, caption, eventID, currentTurnPromptOptions{ - includeLinkScope: true, + currentTurnTextOptions: currentTurnTextOptions{includeLinkScope: true}, attachment: &turnAttachmentOptions{ mediaURL: mediaURL, mimeType: mimeType, @@ -1872,13 +1868,7 @@ func (oc *AIClient) buildContextUpToMessage( base.Messages = append(base.Messages, historyMessages...) body := strings.TrimSpace(newBody) body = airuntime.SanitizeChatMessageForDisplay(body, true) - base.Messages = append(base.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: body, - }}, - }) + base.Messages = append(base.Messages, newUserTextPromptMessage(body)) return base, nil } diff --git a/bridges/ai/constructors.go b/bridges/ai/constructors.go index 2be3a0f4..dbc31714 100644 --- a/bridges/ai/constructors.go +++ b/bridges/ai/constructors.go @@ -3,7 +3,6 @@ package ai import ( "context" - "go.mau.fi/util/configupgrade" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/commands" @@ -61,13 +60,12 @@ func NewAIConnector() *OpenAIConnector { DefaultCommandPrefix: func() string { return bridgesdk.ResolveCommandPrefix(oc.Config.Bridge.CommandPrefix, "!ai") }, - ExampleConfig: exampleNetworkConfig, - ConfigData: &oc.Config, - ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + ExampleConfig: exampleNetworkConfig, + ConfigData: &oc.Config, + NewPortal: func() any { return &PortalMetadata{} }, + NewMessage: func() any { return &MessageMetadata{} }, + NewLogin: func() any { return &UserLoginMetadata{} }, + NewGhost: func() any { return &GhostMetadata{} }, FillBridgeInfo: func(portal *bridgev2.Portal, content *event.BridgeEventContent) { applyAgentRemoteBridgeInfo(portal, portalMeta(portal), content) }, diff --git a/bridges/ai/error_logging.go b/bridges/ai/error_logging.go index 6307918e..13591dc5 100644 --- a/bridges/ai/error_logging.go +++ b/bridges/ai/error_logging.go @@ -55,7 +55,6 @@ func addRequestSummary(event *zerolog.Event, metadata *PortalMetadata, prompt Pr } } event.Int("message_count", len(prompt.Messages)) - event.Bool("has_audio", promptHasAudioContent(prompt)) event.Bool("has_multimodal", promptHasMultimodalContent(prompt)) } diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index bc4a9a4c..57f2fa6b 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -1086,12 +1086,6 @@ func (oc *AIClient) buildContextForRegenerate( return PromptContext{}, err } base.Messages = append(base.Messages, historyMessages...) - base.Messages = append(base.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: latestUserBody, - }}, - }) + base.Messages = append(base.Messages, newUserTextPromptMessage(latestUserBody)) return base, nil } diff --git a/bridges/ai/heartbeat_execute.go b/bridges/ai/heartbeat_execute.go index ae13e85f..357fc908 100644 --- a/bridges/ai/heartbeat_execute.go +++ b/bridges/ai/heartbeat_execute.go @@ -182,19 +182,7 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, promptContext, err := oc.buildHeartbeatTurnContext(context.Background(), sessionPortal, promptMeta, prompt) if err != nil { oc.log.Warn().Str("agent_id", agentID).Str("reason", reason).Err(err).Msg("Heartbeat failed to build prompt") - indicator := (*HeartbeatIndicatorType)(nil) - if hbCfg.UseIndicator { - indicator = resolveIndicatorType("failed") - } - oc.emitHeartbeatEvent(&HeartbeatEventPayload{ - TS: time.Now().UnixMilli(), - Status: "failed", - Reason: err.Error(), - Channel: hbCfg.Channel, - To: hbCfg.TargetRoom.String(), - DurationMs: time.Now().UnixMilli() - startedAtMs, - IndicatorType: indicator, - }) + oc.emitHeartbeatFailure(hbCfg, startedAtMs, err.Error()) return heartbeatRunResult{Status: "failed", Reason: err.Error()} } @@ -228,39 +216,31 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, return heartbeatRunResult{Status: res.Status, Reason: res.Reason} case <-done: oc.log.Warn().Str("agent_id", agentID).Msg("Heartbeat failed: stream completed without outcome") - indicator := (*HeartbeatIndicatorType)(nil) - if hbCfg.UseIndicator { - indicator = resolveIndicatorType("failed") - } - oc.emitHeartbeatEvent(&HeartbeatEventPayload{ - TS: time.Now().UnixMilli(), - Status: "failed", - Reason: "stream-finished-without-outcome", - Channel: hbCfg.Channel, - To: hbCfg.TargetRoom.String(), - DurationMs: time.Now().UnixMilli() - startedAtMs, - IndicatorType: indicator, - }) + oc.emitHeartbeatFailure(hbCfg, startedAtMs, "stream-finished-without-outcome") return heartbeatRunResult{Status: "failed", Reason: "heartbeat failed"} case <-timeoutCtx.Done(): oc.log.Warn().Str("agent_id", agentID).Msg("Heartbeat timed out after 2 minutes") - indicator := (*HeartbeatIndicatorType)(nil) - if hbCfg.UseIndicator { - indicator = resolveIndicatorType("failed") - } - oc.emitHeartbeatEvent(&HeartbeatEventPayload{ - TS: time.Now().UnixMilli(), - Status: "failed", - Reason: "timeout", - Channel: hbCfg.Channel, - To: hbCfg.TargetRoom.String(), - DurationMs: time.Now().UnixMilli() - startedAtMs, - IndicatorType: indicator, - }) + oc.emitHeartbeatFailure(hbCfg, startedAtMs, "timeout") return heartbeatRunResult{Status: "failed", Reason: "heartbeat timed out"} } } +func (oc *AIClient) emitHeartbeatFailure(hbCfg *HeartbeatRunConfig, startedAtMs int64, reason string) { + indicator := (*HeartbeatIndicatorType)(nil) + if hbCfg.UseIndicator { + indicator = resolveIndicatorType("failed") + } + oc.emitHeartbeatEvent(&HeartbeatEventPayload{ + TS: time.Now().UnixMilli(), + Status: "failed", + Reason: reason, + Channel: hbCfg.Channel, + To: hbCfg.TargetRoom.String(), + DurationMs: time.Now().UnixMilli() - startedAtMs, + IndicatorType: indicator, + }) +} + func drainHeartbeatSystemEvents(ownerKey string, primaryKey string, secondaryKey string) []SystemEvent { entries := drainSystemEventEntries(ownerKey, primaryKey) if sk := strings.TrimSpace(secondaryKey); sk != "" && !strings.EqualFold(strings.TrimSpace(primaryKey), sk) { diff --git a/bridges/ai/integrations_config.go b/bridges/ai/integrations_config.go index db359427..fad865eb 100644 --- a/bridges/ai/integrations_config.go +++ b/bridges/ai/integrations_config.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" "github.com/beeper/agentremote/pkg/agents" @@ -464,202 +463,3 @@ type ModelDefinitionConfig struct { // BridgeConfig is an alias for the shared bridge config. type BridgeConfig = bridgeconfig.BridgeConfig - -func upgradeConfig(helper configupgrade.Helper) { - // Beeper credentials for auto-login - helper.Copy(configupgrade.Str, "beeper", "user_mxid") - helper.Copy(configupgrade.Str, "beeper", "base_url") - helper.Copy(configupgrade.Str, "beeper", "token") - - // Model providers and defaults - helper.Copy(configupgrade.Str, "models", "mode") - helper.Copy(configupgrade.Map, "models", "providers") - helper.Copy(configupgrade.Str, "agents", "defaults", "model", "primary") - helper.Copy(configupgrade.List, "agents", "defaults", "model", "fallbacks") - helper.Copy(configupgrade.Str, "agents", "defaults", "image_model", "primary") - helper.Copy(configupgrade.List, "agents", "defaults", "image_model", "fallbacks") - helper.Copy(configupgrade.Str, "agents", "defaults", "image_generation_model", "primary") - helper.Copy(configupgrade.List, "agents", "defaults", "image_generation_model", "fallbacks") - helper.Copy(configupgrade.Str, "agents", "defaults", "pdf_model", "primary") - helper.Copy(configupgrade.List, "agents", "defaults", "pdf_model", "fallbacks") - helper.Copy(configupgrade.Str, "agents", "defaults", "pdf_engine") - - // Global settings - helper.Copy(configupgrade.Str, "default_system_prompt") - helper.Copy(configupgrade.Str, "model_cache_duration") - helper.Copy(configupgrade.Str, "memory", "citations") - helper.Copy(configupgrade.Bool, "memory", "inject_context") - - // Tool approvals - helper.Copy(configupgrade.Bool, "tool_approvals", "enabled") - helper.Copy(configupgrade.Int, "tool_approvals", "ttl_seconds") - helper.Copy(configupgrade.Bool, "tool_approvals", "require_for_mcp") - helper.Copy(configupgrade.List, "tool_approvals", "require_for_tools") - - // Bridge-specific configuration - helper.Copy(configupgrade.Str, "bridge", "command_prefix") - - // Compaction configuration - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "mode") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "ttl") - helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "enabled") - helper.Copy(configupgrade.Float, "agents", "defaults", "compaction", "soft_trim_ratio") - helper.Copy(configupgrade.Float, "agents", "defaults", "compaction", "hard_clear_ratio") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "keep_last_assistants") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "min_prunable_chars") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "soft_trim_max_chars") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "soft_trim_head_chars") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "soft_trim_tail_chars") - helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "hard_clear_enabled") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "hard_clear_placeholder") - helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "summarization_enabled") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "summarization_model") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "max_summary_tokens") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "compaction_mode") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "keep_recent_tokens") - helper.Copy(configupgrade.Float, "agents", "defaults", "compaction", "max_history_share") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "reserve_tokens") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "reserve_tokens_floor") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "custom_instructions") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "identifier_policy") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "identifier_instructions") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "post_compaction_refresh_prompt") - helper.Copy(configupgrade.Bool, "agents", "defaults", "compaction", "overflow_flush", "enabled") - helper.Copy(configupgrade.Int, "agents", "defaults", "compaction", "overflow_flush", "soft_threshold_tokens") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "overflow_flush", "prompt") - helper.Copy(configupgrade.Str, "agents", "defaults", "compaction", "overflow_flush", "system_prompt") - - // Inbound message processing configuration - helper.Copy(configupgrade.Str, "inbound", "dedupe_ttl") - helper.Copy(configupgrade.Int, "inbound", "dedupe_max_size") - helper.Copy(configupgrade.Int, "inbound", "default_debounce_ms") - - // Cron configuration - helper.Copy(configupgrade.Bool, "cron", "enabled") - - // Messages configuration - helper.Copy(configupgrade.Str, "messages", "ack_reaction") - helper.Copy(configupgrade.Str, "messages", "ack_reaction_scope") - helper.Copy(configupgrade.Bool, "messages", "remove_ack_after") - helper.Copy(configupgrade.Int, "messages", "group_chat", "history_limit") - helper.Copy(configupgrade.List, "messages", "group_chat", "mention_patterns") - helper.Copy(configupgrade.Str, "messages", "group_chat", "activation") - helper.Copy(configupgrade.Int, "messages", "direct_chat", "history_limit") - helper.Copy(configupgrade.Int, "messages", "inbound", "debounce_ms") - helper.Copy(configupgrade.Map, "messages", "inbound", "by_channel") - helper.Copy(configupgrade.List, "commands", "owner_allow_from") - helper.Copy(configupgrade.Str, "messages", "queue", "mode") - helper.Copy(configupgrade.Map, "messages", "queue", "by_channel") - helper.Copy(configupgrade.Int, "messages", "queue", "debounce_ms") - helper.Copy(configupgrade.Map, "messages", "queue", "debounce_ms_by_channel") - helper.Copy(configupgrade.Int, "messages", "queue", "cap") - helper.Copy(configupgrade.Str, "messages", "queue", "drop") - - // Session configuration - helper.Copy(configupgrade.Str, "session", "scope") - helper.Copy(configupgrade.Str, "session", "main_key") - - // Agents heartbeat configuration - helper.Copy(configupgrade.Int, "agents", "defaults", "timeout_seconds") - helper.Copy(configupgrade.Str, "agents", "defaults", "user_timezone") - helper.Copy(configupgrade.Str, "agents", "defaults", "envelope_timezone") - helper.Copy(configupgrade.Str, "agents", "defaults", "envelope_timestamp") - helper.Copy(configupgrade.Str, "agents", "defaults", "envelope_elapsed") - helper.Copy(configupgrade.Str, "agents", "defaults", "typing_mode") - helper.Copy(configupgrade.Int, "agents", "defaults", "typing_interval_seconds") - helper.Copy(configupgrade.Map, "agents", "defaults", "subagents") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "every") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "prompt") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "model") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "session") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "target") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "to") - helper.Copy(configupgrade.Int, "agents", "defaults", "heartbeat", "ack_max_chars") - helper.Copy(configupgrade.Bool, "agents", "defaults", "heartbeat", "include_reasoning") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "active_hours", "start") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "active_hours", "end") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "active_hours", "timezone") - helper.Copy(configupgrade.List, "agents", "list") - - // Channels heartbeat visibility - helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "show_ok") - helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "show_alerts") - helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "use_indicator") - helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "show_ok") - helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "show_alerts") - helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "use_indicator") - helper.Copy(configupgrade.Str, "channels", "matrix", "reply_to_mode") - helper.Copy(configupgrade.Str, "channels", "matrix", "thread_replies") - - // Tools (web + links) - helper.Copy(configupgrade.Str, "tools", "web", "search", "provider") - helper.Copy(configupgrade.List, "tools", "web", "search", "fallbacks") - helper.Copy(configupgrade.Bool, "tools", "web", "search", "exa", "enabled") - helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "base_url") - helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "api_key") - helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "type") - helper.Copy(configupgrade.Str, "tools", "web", "search", "exa", "category") - helper.Copy(configupgrade.Int, "tools", "web", "search", "exa", "num_results") - helper.Copy(configupgrade.Bool, "tools", "web", "search", "exa", "include_text") - helper.Copy(configupgrade.Int, "tools", "web", "search", "exa", "text_max_chars") - helper.Copy(configupgrade.Bool, "tools", "web", "search", "exa", "highlights") - helper.Copy(configupgrade.Str, "tools", "web", "fetch", "provider") - helper.Copy(configupgrade.List, "tools", "web", "fetch", "fallbacks") - helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "exa", "enabled") - helper.Copy(configupgrade.Str, "tools", "web", "fetch", "exa", "base_url") - helper.Copy(configupgrade.Str, "tools", "web", "fetch", "exa", "api_key") - helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "exa", "include_text") - helper.Copy(configupgrade.Int, "tools", "web", "fetch", "exa", "text_max_chars") - helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "direct", "enabled") - helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "timeout_seconds") - helper.Copy(configupgrade.Str, "tools", "web", "fetch", "direct", "user_agent") - helper.Copy(configupgrade.Bool, "tools", "web", "fetch", "direct", "readability") - helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "max_chars") - helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "max_redirects") - helper.Copy(configupgrade.Int, "tools", "web", "fetch", "direct", "cache_ttl_seconds") - helper.Copy(configupgrade.Bool, "tools", "links", "enabled") - helper.Copy(configupgrade.Int, "tools", "links", "max_urls_inbound") - helper.Copy(configupgrade.Int, "tools", "links", "max_urls_outbound") - helper.Copy(configupgrade.Str, "tools", "links", "fetch_timeout") - helper.Copy(configupgrade.Int, "tools", "links", "max_content_chars") - helper.Copy(configupgrade.Int, "tools", "links", "max_page_bytes") - helper.Copy(configupgrade.Int, "tools", "links", "max_image_bytes") - helper.Copy(configupgrade.Str, "tools", "links", "cache_ttl") - helper.Copy(configupgrade.Bool, "tools", "mcp", "enable_stdio") - helper.Copy(configupgrade.Int, "tools", "media", "image", "max_bytes") - helper.Copy(configupgrade.Int, "tools", "media", "image", "max_chars") - helper.Copy(configupgrade.Int, "tools", "media", "image", "timeout_seconds") - helper.Copy(configupgrade.Int, "tools", "media", "audio", "max_bytes") - helper.Copy(configupgrade.Int, "tools", "media", "audio", "timeout_seconds") - helper.Copy(configupgrade.Int, "tools", "media", "video", "max_bytes") - helper.Copy(configupgrade.Int, "tools", "media", "video", "timeout_seconds") - - // Memory search configuration - helper.Copy(configupgrade.Bool, "memory_search", "enabled") - helper.Copy(configupgrade.List, "memory_search", "sources") - helper.Copy(configupgrade.List, "memory_search", "extra_paths") - helper.Copy(configupgrade.Str, "memory_search", "store", "driver") - helper.Copy(configupgrade.Str, "memory_search", "store", "path") - helper.Copy(configupgrade.Int, "memory_search", "chunking", "tokens") - helper.Copy(configupgrade.Int, "memory_search", "chunking", "overlap") - helper.Copy(configupgrade.Bool, "memory_search", "sync", "on_session_start") - helper.Copy(configupgrade.Bool, "memory_search", "sync", "on_search") - helper.Copy(configupgrade.Bool, "memory_search", "sync", "watch") - helper.Copy(configupgrade.Int, "memory_search", "sync", "watch_debounce_ms") - helper.Copy(configupgrade.Int, "memory_search", "sync", "interval_minutes") - helper.Copy(configupgrade.Int, "memory_search", "sync", "sessions", "delta_bytes") - helper.Copy(configupgrade.Int, "memory_search", "sync", "sessions", "delta_messages") - helper.Copy(configupgrade.Int, "memory_search", "query", "max_results") - helper.Copy(configupgrade.Float, "memory_search", "query", "min_score") - helper.Copy(configupgrade.Int, "memory_search", "query", "hybrid", "candidate_multiplier") - helper.Copy(configupgrade.Bool, "memory_search", "cache", "enabled") - helper.Copy(configupgrade.Int, "memory_search", "cache", "max_entries") - helper.Copy(configupgrade.Bool, "memory_search", "experimental", "session_memory") - - // Tool policy - helper.Copy(configupgrade.Str, "tool_policy", "profile") - helper.Copy(configupgrade.List, "tool_policy", "allow") - helper.Copy(configupgrade.List, "tool_policy", "also_allow") - helper.Copy(configupgrade.List, "tool_policy", "deny") - helper.Copy(configupgrade.Map, "tool_policy", "by_provider") -} diff --git a/bridges/ai/login.go b/bridges/ai/login.go index 14abc785..4140e491 100644 --- a/bridges/ai/login.go +++ b/bridges/ai/login.go @@ -225,9 +225,13 @@ func (ol *OpenAILogin) finishLogin(ctx context.Context, provider, apiKey, baseUR BaseURL: baseURL, } if serviceTokens != nil && !serviceTokensEmpty(serviceTokens) { - creds.ServiceTokens = serviceTokens + creds.ServiceTokens = cloneServiceTokens(serviceTokens) + } + if loginCredentialsEmpty(creds) { + meta.Credentials = nil + } else { + meta.Credentials = creds } - meta.Credentials = mergeLoginCredentials(meta.Credentials, creds) if err := ol.validateLoginMetadata(ctx, loginID, meta); err != nil { return nil, err } diff --git a/bridges/ai/media_understanding_providers.go b/bridges/ai/media_understanding_providers.go index 513d34f1..57b7f5e0 100644 --- a/bridges/ai/media_understanding_providers.go +++ b/bridges/ai/media_understanding_providers.go @@ -400,9 +400,6 @@ type mediaAudioRequest struct { FileName string } -type mediaVideoRequest = mediaRequestBase -type mediaImageRequest = mediaRequestBase - func resolveMediaFileName(fallback string, msgType string, mediaURL string) string { base := strings.TrimSpace(fallback) if base != "" { diff --git a/bridges/ai/media_understanding_runner.go b/bridges/ai/media_understanding_runner.go index fbce68a5..f4d4ca4f 100644 --- a/bridges/ai/media_understanding_runner.go +++ b/bridges/ai/media_understanding_runner.go @@ -670,27 +670,8 @@ func (oc *AIClient) describeImageWithEntry( if actualMime == "" { actualMime = mimeType } - headers := mergeMediaHeaders(capCfg, entry) - apiKey := oc.resolveMediaProviderAPIKey("google", entry.Profile, entry.PreferredProfile) - if apiKey == "" && !hasProviderAuthHeader("google", headers) { - return nil, errors.New("missing API key for google image understanding") - } - request := mediaImageRequest{ - APIKey: apiKey, - BaseURL: resolveMediaBaseURL(capCfg, entry), - Headers: headers, - Model: strings.TrimSpace(entry.Model), - Prompt: prompt, - MimeType: actualMime, - Data: data, - Timeout: resolveMediaTimeoutSeconds(entry.TimeoutSeconds, capCfg, defaultTimeoutSecondsByCapability[MediaCapabilityImage]), - } - text, err := callGeminiForCapability(ctx, request, MediaCapabilityImage) - if err != nil { - return nil, err - } - text = truncateText(text, maxChars) - return buildMediaOutput(MediaCapabilityImage, text, "google", entry.Model, attachmentIndex), nil + timeout := resolveMediaTimeoutSeconds(entry.TimeoutSeconds, capCfg, defaultTimeoutSecondsByCapability[MediaCapabilityImage]) + return oc.callGeminiMediaCapability(ctx, MediaCapabilityImage, entry, capCfg, data, actualMime, prompt, timeout, maxChars, attachmentIndex) } rawData, actualMime, err := oc.downloadMediaBytes(ctx, mediaURL, encryptedFile, maxBytes, mimeType) @@ -852,13 +833,27 @@ func (oc *AIClient) describeVideoWithEntry( return nil, fmt.Errorf("unsupported video provider: %s", providerID) } + return oc.callGeminiMediaCapability(ctx, MediaCapabilityVideo, entry, capCfg, data, actualMime, prompt, timeout, maxChars, attachmentIndex) +} + +func (oc *AIClient) callGeminiMediaCapability( + ctx context.Context, + capability MediaUnderstandingCapability, + entry MediaUnderstandingModelConfig, + capCfg *MediaUnderstandingConfig, + data []byte, + actualMime string, + prompt string, + timeout time.Duration, + maxChars int, + attachmentIndex int, +) (*MediaUnderstandingOutput, error) { headers := mergeMediaHeaders(capCfg, entry) - apiKey := oc.resolveMediaProviderAPIKey(providerID, entry.Profile, entry.PreferredProfile) - if apiKey == "" && !hasProviderAuthHeader(providerID, headers) { - return nil, fmt.Errorf("missing API key for %s video description", providerID) + apiKey := oc.resolveMediaProviderAPIKey("google", entry.Profile, entry.PreferredProfile) + if apiKey == "" && !hasProviderAuthHeader("google", headers) { + return nil, fmt.Errorf("missing API key for google %s", capability) } - - request := mediaVideoRequest{ + request := mediaRequestBase{ APIKey: apiKey, BaseURL: resolveMediaBaseURL(capCfg, entry), Headers: headers, @@ -868,13 +863,13 @@ func (oc *AIClient) describeVideoWithEntry( Data: data, Timeout: timeout, } - text, err := callGeminiForCapability(ctx, request, MediaCapabilityVideo) + text, err := callGeminiForCapability(ctx, request, capability) if err != nil { return nil, err } text = strings.TrimSpace(text) text = truncateText(text, maxChars) - return buildMediaOutput(MediaCapabilityVideo, text, providerID, entry.Model, attachmentIndex), nil + return buildMediaOutput(capability, text, "google", entry.Model, attachmentIndex), nil } func (oc *AIClient) generateWithOpenRouter( diff --git a/bridges/ai/messages.go b/bridges/ai/messages.go index 1efd3ce9..fd165079 100644 --- a/bridges/ai/messages.go +++ b/bridges/ai/messages.go @@ -1,5 +1,7 @@ package ai +import "strings" + type PromptRole string const ( @@ -40,18 +42,19 @@ type PromptMessage struct { } func (m PromptMessage) Text() string { - var text string + var sb strings.Builder for _, block := range m.Blocks { switch block.Type { case PromptBlockText, PromptBlockThinking: - if text == "" { - text = block.Text - } else if block.Text != "" { - text += "\n" + block.Text + if block.Text != "" { + if sb.Len() > 0 { + sb.WriteByte('\n') + } + sb.WriteString(block.Text) } } } - return text + return sb.String() } // PromptContext is the bridge-local prompt envelope used throughout bridges/ai. diff --git a/bridges/ai/metadata.go b/bridges/ai/metadata.go index 149e6995..8b54aa4a 100644 --- a/bridges/ai/metadata.go +++ b/bridges/ai/metadata.go @@ -211,33 +211,6 @@ func cloneServiceTokens(src *ServiceTokens) *ServiceTokens { return &clone } -func mergeLoginCredentials(existing, incoming *LoginCredentials) *LoginCredentials { - if incoming == nil { - return existing - } - if existing == nil { - clone := *incoming - clone.ServiceTokens = cloneServiceTokens(incoming.ServiceTokens) - if loginCredentialsEmpty(&clone) { - return nil - } - return &clone - } - - merged := *existing - if strings.TrimSpace(incoming.APIKey) != "" { - merged.APIKey = incoming.APIKey - } - if strings.TrimSpace(incoming.BaseURL) != "" { - merged.BaseURL = incoming.BaseURL - } - merged.ServiceTokens = mergeServiceTokens(existing.ServiceTokens, incoming.ServiceTokens) - if loginCredentialsEmpty(&merged) { - return nil - } - return &merged -} - func serviceTokensEmpty(tokens *ServiceTokens) bool { if tokens == nil { return true @@ -353,59 +326,6 @@ func cloneUserLoginMetadata(src *UserLoginMetadata) (*UserLoginMetadata, error) return &clone, nil } -func mergeServiceTokens(existing, incoming *ServiceTokens) *ServiceTokens { - if incoming == nil { - return existing - } - if existing == nil { - clone := *incoming - if incoming.DesktopAPIInstances != nil { - clone.DesktopAPIInstances = maps.Clone(incoming.DesktopAPIInstances) - } - if incoming.MCPServers != nil { - clone.MCPServers = maps.Clone(incoming.MCPServers) - } - return &clone - } - - merged := *existing - if incoming.OpenAI != "" { - merged.OpenAI = incoming.OpenAI - } - if incoming.OpenRouter != "" { - merged.OpenRouter = incoming.OpenRouter - } - if incoming.Exa != "" { - merged.Exa = incoming.Exa - } - if incoming.Brave != "" { - merged.Brave = incoming.Brave - } - if incoming.Perplexity != "" { - merged.Perplexity = incoming.Perplexity - } - if incoming.DesktopAPI != "" { - merged.DesktopAPI = incoming.DesktopAPI - } - if len(incoming.DesktopAPIInstances) > 0 { - if merged.DesktopAPIInstances == nil { - merged.DesktopAPIInstances = make(map[string]DesktopAPIInstance, len(incoming.DesktopAPIInstances)) - } - for key, value := range incoming.DesktopAPIInstances { - merged.DesktopAPIInstances[key] = value - } - } - if len(incoming.MCPServers) > 0 { - if merged.MCPServers == nil { - merged.MCPServers = make(map[string]MCPServerConfig, len(incoming.MCPServers)) - } - for key, value := range incoming.MCPServers { - merged.MCPServers[key] = value - } - } - return &merged -} - func agentsEnabled(meta *UserLoginMetadata) bool { if meta == nil || meta.Agents == nil { return false diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index cee46466..d2b40089 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -333,13 +333,7 @@ func buildSteeringPromptMessages(prompts []string) []PromptMessage { if prompt == "" { continue } - messages = append(messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: prompt, - }}, - }) + messages = append(messages, newUserTextPromptMessage(prompt)) } return messages } diff --git a/bridges/ai/prompt_builder.go b/bridges/ai/prompt_builder.go index 91d7ab13..3ce85c48 100644 --- a/bridges/ai/prompt_builder.go +++ b/bridges/ai/prompt_builder.go @@ -40,12 +40,9 @@ type turnAttachmentOptions struct { } type currentTurnPromptOptions struct { - rawEventContent map[string]any - includeLinkScope bool - prepend []string - append []string - leadingBlocks []PromptBlock - attachment *turnAttachmentOptions + currentTurnTextOptions + leadingBlocks []PromptBlock + attachment *turnAttachmentOptions } func joinPromptFragments(parts ...string) string { @@ -130,11 +127,11 @@ func (oc *AIClient) replayHistoryMessages( skipAssistantID := networkid.MessageID("") if opts.mode == historyReplayRegen { for _, candidate := range candidates { - if skipUserID == "" && candidate.meta != nil && candidate.meta.Role == "user" { + if skipUserID == "" && candidate.meta != nil && candidate.meta.Role == string(PromptRoleUser) { skipUserID = candidate.row.ID continue } - if skipAssistantID == "" && candidate.meta != nil && candidate.meta.Role == "assistant" { + if skipAssistantID == "" && candidate.meta != nil && candidate.meta.Role == string(PromptRoleAssistant) { skipAssistantID = candidate.row.ID } if skipUserID != "" && skipAssistantID != "" { @@ -215,12 +212,9 @@ func (oc *AIClient) buildPromptContextForTurn( appendFragments = append(appendFragments, attachmentAppend...) } - base, text, err := oc.buildCurrentTurnText(ctx, portal, meta, userText, eventID, currentTurnTextOptions{ - rawEventContent: opts.rawEventContent, - includeLinkScope: opts.includeLinkScope, - prepend: opts.prepend, - append: appendFragments, - }) + textOpts := opts.currentTurnTextOptions + textOpts.append = appendFragments + base, text, err := oc.buildCurrentTurnText(ctx, portal, meta, userText, eventID, textOpts) if err != nil { return PromptContext{}, err } @@ -274,8 +268,10 @@ func (oc *AIClient) buildCurrentTurnWithLinks( eventID id.EventID, ) (PromptContext, error) { return oc.buildPromptContextForTurn(ctx, portal, meta, userText, eventID, currentTurnPromptOptions{ - rawEventContent: rawEventContent, - includeLinkScope: true, + currentTurnTextOptions: currentTurnTextOptions{ + rawEventContent: rawEventContent, + includeLinkScope: true, + }, }) } diff --git a/bridges/ai/prompt_context_ops.go b/bridges/ai/prompt_context_ops.go index dca40eb1..2dcff3e3 100644 --- a/bridges/ai/prompt_context_ops.go +++ b/bridges/ai/prompt_context_ops.go @@ -26,12 +26,6 @@ func ClonePromptContext(ctx PromptContext) PromptContext { return cloned } -func AppendPromptMessages(ctx *PromptContext, messages ...PromptMessage) { - if ctx == nil || len(messages) == 0 { - return - } - ctx.Messages = append(ctx.Messages, ClonePromptMessages(messages)...) -} func PromptContextMessageCount(ctx PromptContext) int { count := len(ctx.Messages) @@ -41,7 +35,7 @@ func PromptContextMessageCount(ctx PromptContext) int { return count } -func NewUserTextPromptMessage(text string) PromptMessage { +func newUserTextPromptMessage(text string) PromptMessage { return PromptMessage{ Role: PromptRoleUser, Blocks: []PromptBlock{{ diff --git a/bridges/ai/session_greeting.go b/bridges/ai/session_greeting.go index 7e9b75cd..f393dad9 100644 --- a/bridges/ai/session_greeting.go +++ b/bridges/ai/session_greeting.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" ) @@ -13,19 +12,6 @@ import ( const sessionGreetingPrompt = "A new session was started via !ai reset. Greet the user in your configured persona, if one is provided. Be yourself - use your defined voice, mannerisms, and mood. Keep it to 1-3 sentences and ask what they want to do. If the runtime model differs from default_model in the system prompt, mention the default model. Do not mention internal steps, files, tools, or reasoning." const autoGreetingPrompt = "A new chat was created. Greet the user in your configured persona, if one is provided. Be yourself - use your defined voice, mannerisms, and mood. Keep it to 1-3 sentences and ask what they want to do. If the runtime model differs from default_model in the system prompt, mention the default model. Do not mention internal steps, files, tools, or reasoning." -func maybePrependSessionGreeting( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, - log zerolog.Logger, -) []openai.ChatCompletionMessageParamUnion { - if greeting := sessionGreetingFragment(ctx, portal, meta, log); greeting != "" { - return append([]openai.ChatCompletionMessageParamUnion{openai.SystemMessage(greeting)}, prompt...) - } - return prompt -} - func sessionGreetingFragment( ctx context.Context, portal *bridgev2.Portal, diff --git a/bridges/ai/session_greeting_test.go b/bridges/ai/session_greeting_test.go index 4d801c00..15ecbad2 100644 --- a/bridges/ai/session_greeting_test.go +++ b/bridges/ai/session_greeting_test.go @@ -4,31 +4,23 @@ import ( "context" "testing" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" ) -func TestMaybePrependSessionGreeting(t *testing.T) { +func TestSessionGreetingFragment(t *testing.T) { ctx := context.Background() meta := agentModeTestMeta("beeper") - prompt := []openai.ChatCompletionMessageParamUnion{} - out := maybePrependSessionGreeting(ctx, nil, meta, prompt, zerolog.Nop()) - if len(out) != 1 { - t.Fatalf("expected 1 greeting message, got %d", len(out)) + greeting := sessionGreetingFragment(ctx, nil, meta, zerolog.Nop()) + if greeting != sessionGreetingPrompt { + t.Fatalf("expected greeting prompt, got %q", greeting) } if meta.SessionBootstrapByAgent == nil || meta.SessionBootstrapByAgent["beeper"] == 0 { t.Fatal("expected SessionBootstrapByAgent to be set") } - if out[0].OfSystem == nil { - t.Fatal("expected system message") - } - if out[0].OfSystem.Content.OfString.Value != sessionGreetingPrompt { - t.Fatalf("unexpected greeting content: %q", out[0].OfSystem.Content.OfString.Value) - } - out2 := maybePrependSessionGreeting(ctx, nil, meta, []openai.ChatCompletionMessageParamUnion{}, zerolog.Nop()) - if len(out2) != 0 { - t.Fatalf("expected no additional greeting, got %d", len(out2)) + greeting2 := sessionGreetingFragment(ctx, nil, meta, zerolog.Nop()) + if greeting2 != "" { + t.Fatalf("expected no additional greeting, got %q", greeting2) } } diff --git a/bridges/ai/streaming_continuation.go b/bridges/ai/streaming_continuation.go index 24ad4ffa..0c74d1b2 100644 --- a/bridges/ai/streaming_continuation.go +++ b/bridges/ai/streaming_continuation.go @@ -17,12 +17,10 @@ func (oc *AIClient) buildContinuationParams( pendingOutputs []functionCallOutput, approvalInputs []responses.ResponseInputItemUnionParam, ) responses.ResponseNewParams { - currentPrompt := PromptContext{} + var input responses.ResponseInputParam if prompt != nil { - currentPrompt = ClonePromptContext(*prompt) + input = append(input, promptContextToResponsesInput(*prompt)...) } - var input responses.ResponseInputParam - input = append(input, promptContextToResponsesInput(currentPrompt)...) input = append(input, approvalInputs...) for _, output := range pendingOutputs { if output.name != "" { diff --git a/bridges/ai/streaming_input_conversion.go b/bridges/ai/streaming_input_conversion.go index 72536264..6c021fc8 100644 --- a/bridges/ai/streaming_input_conversion.go +++ b/bridges/ai/streaming_input_conversion.go @@ -1,10 +1,5 @@ package ai -func promptHasAudioContent(prompt PromptContext) bool { - _ = prompt - return false -} - func promptHasMultimodalContent(prompt PromptContext) bool { for _, msg := range prompt.Messages { for _, block := range msg.Blocks { From 0ee13bc6cce8f6c080103c91e8be943c91553b9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:40:35 +0200 Subject: [PATCH 10/23] Update prompt_context_ops.go --- bridges/ai/prompt_context_ops.go | 1 - 1 file changed, 1 deletion(-) diff --git a/bridges/ai/prompt_context_ops.go b/bridges/ai/prompt_context_ops.go index 2dcff3e3..2a6a1933 100644 --- a/bridges/ai/prompt_context_ops.go +++ b/bridges/ai/prompt_context_ops.go @@ -26,7 +26,6 @@ func ClonePromptContext(ctx PromptContext) PromptContext { return cloned } - func PromptContextMessageCount(ctx PromptContext) int { count := len(ctx.Messages) if strings.TrimSpace(ctx.SystemPrompt) != "" { From fc457feb77929f40990a64342f1040b61f1b1476 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 22:49:45 +0200 Subject: [PATCH 11/23] Remove OpenAI chat provider and update AI helpers Remove deprecated OpenAI chat provider and related audio helper, wrap background sync call, fix OpenCode streaming logic, and replace room-meta sending with Bot.SendState. - bridges/ai/provider_openai_chat.go: delete OpenAI chat provider implementation. - bridges/ai/client.go: remove getAudioFormat helper. - bridges/openclaw/manager.go: wrap syncSessions invocation in a goroutine that explicitly discards the returned error. - bridges/opencode/opencode_canonical_stream.go: stop reassigning delivered when appending assistant text (prevents overwriting delivered state). - helpers.go: replace portal.Internal().SendRoomMeta usage with portal.Bridge.Bot.SendState, return bool success and log failures via zerolog (avoids linting on deprecated internals and uses explicit state API). - sdk/conversation_state_test.go: adjust test to ignore the boolean return from saveConversationStateToGenericMetadata with a clarifying comment. These changes consolidate provider removal, address a lint/deprecation concern, improve error visibility for async syncs, and correct streaming/text delivery behavior. --- bridges/ai/client.go | 21 -------- bridges/ai/provider_openai_chat.go | 54 ------------------- bridges/openclaw/manager.go | 4 +- bridges/opencode/opencode_canonical_stream.go | 1 - helpers.go | 20 ++++--- sdk/conversation_state_test.go | 6 +-- 6 files changed, 15 insertions(+), 91 deletions(-) delete mode 100644 bridges/ai/provider_openai_chat.go diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 251536f7..9ccf2318 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -1886,27 +1886,6 @@ func (oc *AIClient) downloadAndEncodeMedia(ctx context.Context, mxcURL string, e return base64.StdEncoding.EncodeToString(data), mimeType, nil } -// getAudioFormat extracts the audio format from a MIME type for OpenRouter API -func getAudioFormat(mimeType string) string { - switch mimeType { - case "audio/wav", "audio/x-wav": - return "wav" - case "audio/mpeg", "audio/mp3": - return "mp3" - case "audio/webm": - return "webm" - case "audio/ogg": - return "ogg" - case "audio/flac": - return "flac" - case "audio/mp4", "audio/x-m4a": - return "mp4" - default: - // Default to mp3 for unknown formats - return "mp3" - } -} - // ensureGhostDisplayName ensures the ghost has its display name set before sending messages. // This fixes the issue where ghosts appear with raw user IDs instead of formatted names. func (oc *AIClient) ensureGhostDisplayName(ctx context.Context, modelID string) { diff --git a/bridges/ai/provider_openai_chat.go b/bridges/ai/provider_openai_chat.go deleted file mode 100644 index a9a6a260..00000000 --- a/bridges/ai/provider_openai_chat.go +++ /dev/null @@ -1,54 +0,0 @@ -package ai - -import ( - "context" - "errors" - "fmt" - - "github.com/openai/openai-go/v3" -) - -func (o *OpenAIProvider) generateChatCompletions(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - chatMessages := promptContextToChatCompletionMessages(params.Context, isOpenRouterBaseURL(o.baseURL)) - if len(chatMessages) == 0 { - return nil, errors.New("no chat messages for completion") - } - - req := openai.ChatCompletionNewParams{ - Model: params.Model, - Messages: chatMessages, - } - if params.MaxCompletionTokens > 0 { - req.MaxCompletionTokens = openai.Int(int64(params.MaxCompletionTokens)) - } - if params.Temperature != nil { - req.Temperature = openai.Float(*params.Temperature) - } - if len(params.Context.Tools) > 0 { - req.Tools = ToOpenAIChatTools(params.Context.Tools, resolveToolStrictMode(isOpenRouterBaseURL(o.baseURL)), &o.log) - req.Tools = dedupeChatToolParams(req.Tools) - } - - resp, err := o.client.Chat.Completions.New(ctx, req) - if err != nil { - return nil, fmt.Errorf("OpenAI chat completion failed: %w", err) - } - - var content string - var finishReason string - if len(resp.Choices) > 0 { - content = resp.Choices[0].Message.Content - finishReason = resp.Choices[0].FinishReason - } - - return &GenerateResponse{ - Content: content, - FinishReason: finishReason, - Usage: UsageInfo{ - PromptTokens: int(resp.Usage.PromptTokens), - CompletionTokens: int(resp.Usage.CompletionTokens), - TotalTokens: int(resp.Usage.TotalTokens), - ReasoningTokens: int(resp.Usage.CompletionTokensDetails.ReasoningTokens), - }, - }, nil -} diff --git a/bridges/openclaw/manager.go b/bridges/openclaw/manager.go index 45f04526..35e0124e 100644 --- a/bridges/openclaw/manager.go +++ b/bridges/openclaw/manager.go @@ -593,7 +593,9 @@ func (m *openClawManager) HandleMatrixMessage(ctx context.Context, msg *bridgev2 return nil, err } if meta.OpenClawDMCreatedFromContact && meta.OpenClawSessionID == "" && isOpenClawSyntheticDMSessionKey(meta.OpenClawSessionKey) { - go m.syncSessions(m.client.BackgroundContext(ctx)) + go func() { + _ = m.syncSessions(m.client.BackgroundContext(ctx)) + }() } return &bridgev2.MatrixMessageResponse{Pending: true}, nil } diff --git a/bridges/opencode/opencode_canonical_stream.go b/bridges/opencode/opencode_canonical_stream.go index 84fc4c4f..4664a333 100644 --- a/bridges/opencode/opencode_canonical_stream.go +++ b/bridges/opencode/opencode_canonical_stream.go @@ -66,7 +66,6 @@ func (m *OpenCodeManager) syncAssistantTextPart(ctx context.Context, inst *openC "delta": text, }) inst.appendPartTextContent(part.SessionID, part.ID, kind, text) - delivered = text } } else if missing, ok := strings.CutPrefix(text, delivered); ok && missing != "" { m.bridge.emitOpenCodeStreamEvent(ctx, portal, turnID, agentID, map[string]any{ diff --git a/helpers.go b/helpers.go index df2c7362..dfd799e8 100644 --- a/helpers.go +++ b/helpers.go @@ -456,17 +456,15 @@ func SendAIRoomInfo(ctx context.Context, portal *bridgev2.Portal, aiKind string) if aiKind == "" { aiKind = AIRoomKindAgent } - //lint:ignore SA1019 bridgev2 currently exposes room-meta sending via portal internals - return portal.Internal().SendRoomMeta( - ctx, - nil, - time.Now(), - matrixevents.AIRoomInfoEventType, - "", - map[string]any{"type": aiKind}, - true, - nil, - ) + _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, matrixevents.AIRoomInfoEventType, "", &event.Content{ + Parsed: map[string]any{"type": aiKind}, + Raw: map[string]any{"com.beeper.exclude_from_timeline": true}, + }, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to send AI room info state event") + return false + } + return true } // findExistingMessage performs a two-phase message lookup: first by network diff --git a/sdk/conversation_state_test.go b/sdk/conversation_state_test.go index a5775253..bbef53de 100644 --- a/sdk/conversation_state_test.go +++ b/sdk/conversation_state_test.go @@ -79,9 +79,9 @@ func TestConversationStateRoundTripCarrierMetadata(t *testing.T) { AgentIDs: []string{"agent-a"}, }, } - if !saveConversationStateToGenericMetadata(&holder, state) { - // Generic metadata intentionally doesn't support the carrier path. - } + // saveConversationStateToGenericMetadata intentionally returns false here + // because generic metadata doesn't support the carrier path. + _ = saveConversationStateToGenericMetadata(&holder, state) carrier.SetSDKPortalMetadata(&SDKPortalMetadata{Conversation: *state}) loaded, ok := carrier.GetSDKPortalMetadata(), carrier.GetSDKPortalMetadata() != nil if !ok || loaded == nil { From 61a72f250739454c9250eae507d7b3b0687a5136 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 23:41:28 +0200 Subject: [PATCH 12/23] Refactor integration types and introduce generics Migrate integration and connector code to stronger typed APIs and generics. Key changes: - Use integrationruntime.ModuleHooks and concrete integrationruntime types instead of generic any for module registration, scopes and callbacks. - Update bridgesdk.Config usages to the generic form (e.g. bridgesdk.Config[*AIClient, *Config]) and adjust NewStandardConnectorConfig parameter factories to return concrete pointer types. - Replace many host/handler signatures to accept concrete types (e.g. *bridgev2.Portal, *PortalMetadata, *dbutil.Database) and remove redundant runtime casts. - Add PortalMetadata helper accessors (AgentID, CompactionCounter, InternalRoom, ModuleMetaValue, SetModuleMetaValue). - Improve compaction retry logic by caching token estimates from preflight flush, return an int from the preflight hook, and emit lifecycle events without embedding client references. - Simplify header handling in media understanding helpers to avoid unnecessary copying. - Propagate truncated flag into file message XML and simplify XML building. - Rename and tighten dummybridge session helpers (sessionFromAny -> requireSession) and update related APIs to use typed sessions. - Adjust tests to use the new generic bridgesdk.NewConversation signatures and other updated APIs. These changes tighten type safety, reduce runtime type assertions, and prepare the codebase for clearer integration APIs. --- bridges/ai/client.go | 3 +- bridges/ai/command_registry.go | 1 - bridges/ai/connector.go | 2 +- bridges/ai/constructors.go | 10 +- bridges/ai/heartbeat_execute.go | 12 +- bridges/ai/integration_host.go | 170 +++++------------ bridges/ai/integrations.go | 33 ++-- bridges/ai/integrations_test.go | 2 + bridges/ai/media_understanding_runner.go | 16 +- bridges/ai/metadata.go | 26 +++ bridges/ai/response_finalization_test.go | 2 +- bridges/ai/response_retry.go | 20 +- bridges/ai/streaming_error_handling_test.go | 2 +- bridges/ai/streaming_init.go | 2 +- bridges/ai/streaming_output_items_test.go | 2 +- bridges/ai/streaming_ui_tools_test.go | 2 +- bridges/ai/system_prompts.go | 22 +-- bridges/ai/text_files.go | 9 +- bridges/codex/connector.go | 2 +- bridges/codex/constructors.go | 10 +- bridges/codex/stream_mapping_test.go | 2 +- bridges/dummybridge/bridge.go | 25 ++- bridges/dummybridge/connector.go | 12 +- bridges/dummybridge/runtime.go | 4 +- bridges/dummybridge/runtime_test.go | 2 +- bridges/openclaw/connector.go | 12 +- bridges/openclaw/stream_test.go | 4 +- bridges/opencode/connector.go | 14 +- pkg/integrations/cron/integration.go | 5 +- pkg/integrations/memory/integration.go | 103 ++++------- pkg/integrations/memory/manager.go | 6 +- pkg/integrations/memory/module_exec.go | 1 - pkg/integrations/memory/overflow_exec.go | 20 +- pkg/integrations/memory/prompt_exec.go | 18 +- pkg/integrations/memory/sessions.go | 6 +- pkg/integrations/runtime/helpers.go | 5 +- pkg/integrations/runtime/host_types.go | 17 +- pkg/integrations/runtime/interfaces.go | 19 +- pkg/integrations/runtime/module_hooks.go | 99 +++------- sdk/client.go | 193 ++++++++++++-------- sdk/client_resolution_test.go | 10 +- sdk/commands.go | 2 +- sdk/connector.go | 4 +- sdk/connector_helpers.go | 35 ++-- sdk/connector_hooks_test.go | 6 +- sdk/conversation.go | 45 ++--- sdk/conversation_test.go | 10 +- sdk/login_handle.go | 4 +- sdk/part_apply_test.go | 2 +- sdk/runtime.go | 65 +++++-- sdk/turn.go | 4 +- sdk/turn_test.go | 10 +- sdk/types.go | 36 ++-- 53 files changed, 554 insertions(+), 594 deletions(-) diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 9ccf2318..1d277789 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -26,6 +26,7 @@ import ( "github.com/beeper/agentremote" "github.com/beeper/agentremote/pkg/agents" + integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" airuntime "github.com/beeper/agentremote/pkg/runtime" "github.com/beeper/agentremote/pkg/shared/stringutil" ) @@ -311,7 +312,7 @@ type AIClient struct { // Heartbeat + integrations scheduler *schedulerRuntime - integrationModules map[string]any + integrationModules map[string]integrationruntime.ModuleHooks integrationOrder []string toolRegistry *toolIntegrationRegistry diff --git a/bridges/ai/command_registry.go b/bridges/ai/command_registry.go index 84fba72f..57afaefe 100644 --- a/bridges/ai/command_registry.go +++ b/bridges/ai/command_registry.go @@ -96,7 +96,6 @@ func registerModuleCommands(defs []integrationruntime.CommandDefinition) { ce.Ctx, ce.Portal, meta, - ce, commandName, ce.Args, ce.RawArgs, diff --git a/bridges/ai/connector.go b/bridges/ai/connector.go index f6d822c7..1cc0c84e 100644 --- a/bridges/ai/connector.go +++ b/bridges/ai/connector.go @@ -37,7 +37,7 @@ type OpenAIConnector struct { br *bridgev2.Bridge Config Config db *dbutil.Database - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*AIClient, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI diff --git a/bridges/ai/constructors.go b/bridges/ai/constructors.go index dbc31714..920359e7 100644 --- a/bridges/ai/constructors.go +++ b/bridges/ai/constructors.go @@ -17,7 +17,7 @@ func NewAIConnector() *OpenAIConnector { oc := &OpenAIConnector{ clients: make(map[networkid.UserLoginID]bridgev2.NetworkAPI), } - oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*AIClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "ai", Description: "AI Chats for Beeper, built on mautrix-go bridgev2.", ProtocolID: "ai", @@ -62,10 +62,10 @@ func NewAIConnector() *OpenAIConnector { }, ExampleConfig: exampleNetworkConfig, ConfigData: &oc.Config, - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, FillBridgeInfo: func(portal *bridgev2.Portal, content *event.BridgeEventContent) { applyAgentRemoteBridgeInfo(portal, portalMeta(portal), content) }, diff --git a/bridges/ai/heartbeat_execute.go b/bridges/ai/heartbeat_execute.go index 357fc908..cb1e37a2 100644 --- a/bridges/ai/heartbeat_execute.go +++ b/bridges/ai/heartbeat_execute.go @@ -79,7 +79,7 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, sessionResolution := oc.resolveHeartbeatSession(agentID, heartbeat) storeKey := strings.TrimSpace(sessionResolution.SessionKey) - sessionPortal, sessionKey, err := oc.resolveHeartbeatSessionPortal(agentID, heartbeat) + sessionPortal, sessionKey, err := oc.resolveHeartbeatSessionPortal(agentID, heartbeat, sessionResolution) if err != nil || sessionPortal == nil || sessionPortal.MXID == "" { oc.log.Warn().Str("agent_id", agentID).Err(err).Msg("Heartbeat skipped: no session portal") return heartbeatRunResult{Status: "skipped", Reason: "no-session"} @@ -262,7 +262,13 @@ func systemEventsOwnerKey(oc *AIClient) string { return string(oc.UserLogin.Bridge.DB.BridgeID) + "|" + string(oc.UserLogin.ID) } -func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *HeartbeatConfig) (*bridgev2.Portal, string, error) { +func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *HeartbeatConfig, preResolved ...heartbeatSessionResolution) (*bridgev2.Portal, string, error) { + var hbSession heartbeatSessionResolution + if len(preResolved) > 0 && preResolved[0].SessionKey != "" { + hbSession = preResolved[0] + } else { + hbSession = oc.resolveHeartbeatSession(agentID, heartbeat) + } session := "" if heartbeat != nil && heartbeat.Session != nil { session = strings.TrimSpace(*heartbeat.Session) @@ -272,7 +278,6 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea mainKey = strings.TrimSpace(oc.connector.Config.Session.MainKey) } if session == "" || strings.EqualFold(session, "main") || strings.EqualFold(session, "global") || (mainKey != "" && strings.EqualFold(session, mainKey)) { - hbSession := oc.resolveHeartbeatSession(agentID, heartbeat) if portal := oc.heartbeatSessionPortalCandidate(agentID, hbSession); portal != nil { return portal, portal.MXID.String(), nil } @@ -291,7 +296,6 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea } } } - hbSession := oc.resolveHeartbeatSession(agentID, heartbeat) if portal := oc.heartbeatSessionPortalCandidate(agentID, hbSession); portal != nil { return portal, portal.MXID.String(), nil } diff --git a/bridges/ai/integration_host.go b/bridges/ai/integration_host.go index 90325d25..1f6c25aa 100644 --- a/bridges/ai/integration_host.go +++ b/bridges/ai/integration_host.go @@ -9,6 +9,7 @@ import ( "github.com/openai/openai-go/v3" "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -120,7 +121,7 @@ func (h *runtimeIntegrationHost) AgentModuleConfig(agentID string, module string // ---- Host methods: logger access ---- -func (h *runtimeIntegrationHost) RawLogger() any { +func (h *runtimeIntegrationHost) RawLogger() zerolog.Logger { if h == nil || h.client == nil { return zerolog.Logger{} } @@ -129,7 +130,7 @@ func (h *runtimeIntegrationHost) RawLogger() any { // ---- Host methods: portal management ---- -func (h *runtimeIntegrationHost) GetOrCreatePortal(ctx context.Context, portalID string, receiver string, displayName string, setupMeta func(meta any)) (portal any, roomID string, err error) { +func (h *runtimeIntegrationHost) GetOrCreatePortal(ctx context.Context, portalID string, receiver string, displayName string, setupMeta func(meta *PortalMetadata)) (portal *bridgev2.Portal, roomID string, err error) { if h == nil || h.client == nil || h.client.UserLogin == nil { return nil, "", fmt.Errorf("missing login") } @@ -155,122 +156,55 @@ func (h *runtimeIntegrationHost) GetOrCreatePortal(ctx context.Context, portalID return p, p.MXID.String(), nil } -func (h *runtimeIntegrationHost) SavePortal(ctx context.Context, portal any, reason string) error { +func (h *runtimeIntegrationHost) SavePortal(ctx context.Context, portal *bridgev2.Portal, reason string) error { if h == nil || h.client == nil { return nil } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return nil } - h.client.savePortalQuiet(ctx, p, reason) + h.client.savePortalQuiet(ctx, portal, reason) return nil } -func (h *runtimeIntegrationHost) PortalRoomID(portal any) string { - p, _ := portal.(*bridgev2.Portal) - if p == nil { +func (h *runtimeIntegrationHost) PortalRoomID(portal *bridgev2.Portal) string { + if portal == nil { return "" } - return p.MXID.String() + return portal.MXID.String() } -func (h *runtimeIntegrationHost) PortalKeyString(portal any) string { - p, _ := portal.(*bridgev2.Portal) - if p == nil { +func (h *runtimeIntegrationHost) PortalKeyString(portal *bridgev2.Portal) string { + if portal == nil { return "" } - return p.PortalKey.String() + return portal.PortalKey.String() } -// ---- Host methods: metadata access ---- - -func (h *runtimeIntegrationHost) GetModuleMeta(meta any, key string) any { - m, _ := meta.(*PortalMetadata) - if m == nil || m.ModuleMeta == nil { - return nil - } - return m.ModuleMeta[key] -} - -func (h *runtimeIntegrationHost) SetModuleMeta(meta any, key string, value any) { - m, _ := meta.(*PortalMetadata) - if m == nil { - return - } - m.SetModuleMeta(key, value) -} - -func (h *runtimeIntegrationHost) AgentIDFromMeta(meta any) string { - m, _ := meta.(*PortalMetadata) - return resolveAgentID(m) -} - -func (h *runtimeIntegrationHost) CompactionCount(meta any) int { - m, _ := meta.(*PortalMetadata) - if m == nil { - return 0 - } - return m.CompactionCount -} - -func (h *runtimeIntegrationHost) IsGroupChat(ctx context.Context, portal any) bool { +func (h *runtimeIntegrationHost) IsGroupChat(ctx context.Context, portal *bridgev2.Portal) bool { if h == nil || h.client == nil { return false } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return false } - return h.client.isGroupChat(ctx, p) -} - -func (h *runtimeIntegrationHost) IsInternalRoom(meta any) bool { - m, _ := meta.(*PortalMetadata) - if m == nil { - return false - } - return isModuleInternalRoom(m) -} - -func (h *runtimeIntegrationHost) PortalMeta(portal any) any { - p, _ := portal.(*bridgev2.Portal) - return portalMeta(p) -} - -func (h *runtimeIntegrationHost) CloneMeta(portal any) any { - p, _ := portal.(*bridgev2.Portal) - return clonePortalMetadata(portalMeta(p)) -} - -func (h *runtimeIntegrationHost) SetMetaField(meta any, key string, value any) { - m, _ := meta.(*PortalMetadata) - if m == nil { - return - } - switch key { - case "DisabledTools": - if v, ok := value.([]string); ok { - m.DisabledTools = v - } - } + return h.client.isGroupChat(ctx, portal) } // ---- Host methods: message helpers ---- -func (h *runtimeIntegrationHost) RecentMessages(ctx context.Context, portal any, count int) []integrationruntime.MessageSummary { +func (h *runtimeIntegrationHost) RecentMessages(ctx context.Context, portal *bridgev2.Portal, count int) []integrationruntime.MessageSummary { if h == nil || h.client == nil { return nil } - p, _ := portal.(*bridgev2.Portal) - if p == nil || count <= 0 || h.client.UserLogin == nil || h.client.UserLogin.Bridge == nil || h.client.UserLogin.Bridge.DB == nil { + if portal == nil || count <= 0 || h.client.UserLogin == nil || h.client.UserLogin.Bridge == nil || h.client.UserLogin.Bridge.DB == nil { return nil } maxMessages := count if maxMessages > 10 { maxMessages = 10 } - history, err := h.client.UserLogin.Bridge.DB.Message.GetLastNInPortal(h.client.backgroundContext(ctx), p.PortalKey, maxMessages) + history, err := h.client.UserLogin.Bridge.DB.Message.GetLastNInPortal(h.client.backgroundContext(ctx), portal.PortalKey, maxMessages) if err != nil || len(history) == 0 { return nil } @@ -293,20 +227,18 @@ func (h *runtimeIntegrationHost) RecentMessages(ctx context.Context, portal any, return out } -func (h *runtimeIntegrationHost) LastAssistantMessage(ctx context.Context, portal any) (id string, timestamp int64) { +func (h *runtimeIntegrationHost) LastAssistantMessage(ctx context.Context, portal *bridgev2.Portal) (id string, timestamp int64) { if h == nil || h.client == nil { return "", 0 } - p, _ := portal.(*bridgev2.Portal) - return h.client.lastAssistantMessageInfo(ctx, p) + return h.client.lastAssistantMessageInfo(ctx, portal) } -func (h *runtimeIntegrationHost) WaitForAssistantMessage(ctx context.Context, portal any, afterID string, afterTS int64) (*integrationruntime.AssistantMessageInfo, bool) { +func (h *runtimeIntegrationHost) WaitForAssistantMessage(ctx context.Context, portal *bridgev2.Portal, afterID string, afterTS int64) (*integrationruntime.AssistantMessageInfo, bool) { if h == nil || h.client == nil { return nil, false } - p, _ := portal.(*bridgev2.Portal) - msg, found := h.client.waitForNewAssistantMessage(ctx, p, afterID, afterTS) + msg, found := h.client.waitForNewAssistantMessage(ctx, portal, afterID, afterTS) if !found || msg == nil { return nil, false } @@ -331,7 +263,7 @@ func (h *runtimeIntegrationHost) RunHeartbeatOnce(ctx context.Context, reason st return h.client.scheduler.RunHeartbeatSweep(ctx, reason) } -func (h *runtimeIntegrationHost) ResolveHeartbeatSessionPortal(agentID string) (portal any, sessionKey string, err error) { +func (h *runtimeIntegrationHost) ResolveHeartbeatSessionPortal(agentID string) (portal *bridgev2.Portal, sessionKey string, err error) { if h == nil || h.client == nil { return nil, "", fmt.Errorf("missing client") } @@ -450,7 +382,7 @@ func (h *runtimeIntegrationHost) NormalizeThinkingLevel(raw string) (string, boo // ---- Host methods: model helpers ---- -func (h *runtimeIntegrationHost) EffectiveModel(meta any) string { +func (h *runtimeIntegrationHost) EffectiveModel(meta integrationruntime.Meta) string { if h == nil || h.client == nil { return "" } @@ -458,7 +390,7 @@ func (h *runtimeIntegrationHost) EffectiveModel(meta any) string { return h.client.effectiveModel(m) } -func (h *runtimeIntegrationHost) ContextWindow(meta any) int { +func (h *runtimeIntegrationHost) ContextWindow(meta integrationruntime.Meta) int { if h == nil || h.client == nil { return 0 } @@ -509,15 +441,14 @@ func (h *runtimeIntegrationHost) BackgroundContext(ctx context.Context) context. // ---- Host methods: chat completions ---- -func (h *runtimeIntegrationHost) NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams any) (*integrationruntime.CompletionResult, error) { +func (h *runtimeIntegrationHost) NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams []openai.ChatCompletionToolUnionParam) (*integrationruntime.CompletionResult, error) { if h == nil || h.client == nil { return nil, fmt.Errorf("missing client") } - params, _ := toolParams.([]openai.ChatCompletionToolUnionParam) req := openai.ChatCompletionNewParams{ Model: model, Messages: messages, - Tools: params, + Tools: toolParams, } resp, err := h.client.api.Chat.Completions.New(ctx, req) if err != nil { @@ -549,7 +480,7 @@ func (h *runtimeIntegrationHost) NewCompletion(ctx context.Context, model string // ---- Host methods: tool policy ---- -func (h *runtimeIntegrationHost) IsToolEnabled(meta any, toolName string) bool { +func (h *runtimeIntegrationHost) IsToolEnabled(meta integrationruntime.Meta, toolName string) bool { if h == nil || h.client == nil { return true } @@ -567,24 +498,23 @@ func (h *runtimeIntegrationHost) AllToolDefinitions() []integrationruntime.ToolD return out } -func (h *runtimeIntegrationHost) ExecuteToolInContext(ctx context.Context, portal any, meta any, name string, argsJSON string) (string, error) { +func (h *runtimeIntegrationHost) ExecuteToolInContext(ctx context.Context, portal *bridgev2.Portal, meta integrationruntime.Meta, name string, argsJSON string) (string, error) { if h == nil || h.client == nil { return "", fmt.Errorf("missing client") } - p, _ := portal.(*bridgev2.Portal) m, _ := meta.(*PortalMetadata) if m != nil && !h.client.isToolEnabled(m, name) { return "", fmt.Errorf("tool %s is disabled", name) } toolCtx := WithBridgeToolContext(ctx, &BridgeToolContext{ Client: h.client, - Portal: p, + Portal: portal, Meta: m, }) - return h.client.executeBuiltinTool(toolCtx, p, name, argsJSON) + return h.client.executeBuiltinTool(toolCtx, portal, name, argsJSON) } -func (h *runtimeIntegrationHost) ToolsToOpenAIParams(tools []integrationruntime.ToolDefinition) any { +func (h *runtimeIntegrationHost) ToolsToOpenAIParams(tools []integrationruntime.ToolDefinition) []openai.ChatCompletionToolUnionParam { if h == nil || h.client == nil { return nil } @@ -614,7 +544,7 @@ func (h *runtimeIntegrationHost) ReadTextFile(ctx context.Context, agentID strin return entry.Content, entry.Path, true, nil } -func (h *runtimeIntegrationHost) WriteTextFile(ctx context.Context, portal any, meta any, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) { +func (h *runtimeIntegrationHost) WriteTextFile(ctx context.Context, portal *bridgev2.Portal, meta integrationruntime.Meta, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) { if h == nil || h.client == nil { return "", fmt.Errorf("storage unavailable") } @@ -644,11 +574,10 @@ func (h *runtimeIntegrationHost) WriteTextFile(ctx context.Context, portal any, return "", e } if entry != nil { - p, _ := portal.(*bridgev2.Portal) m, _ := meta.(*PortalMetadata) toolCtx := WithBridgeToolContext(ctx, &BridgeToolContext{ Client: h.client, - Portal: p, + Portal: portal, Meta: m, }) notifyIntegrationFileChanged(toolCtx, entry.Path) @@ -766,7 +695,7 @@ func (h *runtimeIntegrationHost) SessionPortals(ctx context.Context, loginID str return out, nil } -func (h *runtimeIntegrationHost) LoginDB() any { +func (h *runtimeIntegrationHost) LoginDB() *dbutil.Database { if h == nil || h.client == nil { return nil } @@ -819,52 +748,49 @@ func (h *runtimeIntegrationHost) CronRun(ctx context.Context, jobID string) (boo // ---- Host methods: dispatch/lookup primitives ---- -func (h *runtimeIntegrationHost) ResolvePortalByRoomID(ctx context.Context, roomID string) any { +func (h *runtimeIntegrationHost) ResolvePortalByRoomID(ctx context.Context, roomID string) *bridgev2.Portal { if h == nil || h.client == nil || strings.TrimSpace(roomID) == "" { return nil } return h.client.portalByRoomID(ctx, portalRoomIDFromString(roomID)) } -func (h *runtimeIntegrationHost) ResolveDefaultPortal(ctx context.Context) any { +func (h *runtimeIntegrationHost) ResolveDefaultPortal(ctx context.Context) *bridgev2.Portal { if h == nil || h.client == nil { return nil } return h.client.defaultChatPortal() } -func (h *runtimeIntegrationHost) ResolveLastActivePortal(ctx context.Context, agentID string) any { +func (h *runtimeIntegrationHost) ResolveLastActivePortal(ctx context.Context, agentID string) *bridgev2.Portal { if h == nil || h.client == nil { return nil } return h.client.lastActivePortal(agentID) } -func (h *runtimeIntegrationHost) DispatchInternalMessage(ctx context.Context, portal any, meta any, message string, source string) error { +func (h *runtimeIntegrationHost) DispatchInternalMessage(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, message string, source string) error { if h == nil || h.client == nil { return fmt.Errorf("missing client") } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return fmt.Errorf("missing portal") } - m, _ := meta.(*PortalMetadata) - if m == nil { - m = &PortalMetadata{} + if meta == nil { + meta = &PortalMetadata{} } - _, _, err := h.client.dispatchInternalMessage(ctx, p, m, message, source, false) + _, _, err := h.client.dispatchInternalMessage(ctx, portal, meta, message, source, false) return err } -func (h *runtimeIntegrationHost) SendAssistantMessage(ctx context.Context, portal any, body string) error { +func (h *runtimeIntegrationHost) SendAssistantMessage(ctx context.Context, portal *bridgev2.Portal, body string) error { if h == nil || h.client == nil { return fmt.Errorf("missing client") } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return fmt.Errorf("missing portal") } - return h.client.sendPlainAssistantMessage(ctx, p, body) + return h.client.sendPlainAssistantMessage(ctx, portal, body) } func (h *runtimeIntegrationHost) RequestNow(ctx context.Context, reason string) { @@ -887,7 +813,7 @@ func (h *runtimeIntegrationHost) ExecuteBuiltinTool(ctx context.Context, scope i if h == nil || h.client == nil { return "", fmt.Errorf("missing client") } - portal, _ := scope.Portal.(*bridgev2.Portal) + portal := scope.Portal meta, _ := scope.Meta.(*PortalMetadata) if meta != nil && !h.client.isToolEnabled(meta, name) { return "", fmt.Errorf("tool %s is disabled", name) @@ -904,7 +830,7 @@ func (h *runtimeIntegrationHost) ResolveWorkspaceDir() string { return resolvePromptWorkspaceDir() } -func (h *runtimeIntegrationHost) BridgeDB() any { +func (h *runtimeIntegrationHost) BridgeDB() *dbutil.Database { if h == nil || h.client == nil { return nil } diff --git a/bridges/ai/integrations.go b/bridges/ai/integrations.go index 9f8e9177..86db2db3 100644 --- a/bridges/ai/integrations.go +++ b/bridges/ai/integrations.go @@ -219,18 +219,15 @@ func settingSourceFromIntegration(source integrationruntime.SettingSource) Setti func (oc *AIClient) toolScope(portal *bridgev2.Portal, meta *PortalMetadata) integrationruntime.ToolScope { return integrationruntime.ToolScope{ - Client: oc, Portal: portal, Meta: meta, } } -func (oc *AIClient) commandScope(portal *bridgev2.Portal, meta *PortalMetadata, evt any) integrationruntime.CommandScope { +func (oc *AIClient) commandScope(portal *bridgev2.Portal, meta *PortalMetadata) integrationruntime.CommandScope { return integrationruntime.CommandScope{ - Client: oc, Portal: portal, Meta: meta, - Event: evt, } } @@ -243,7 +240,7 @@ func (oc *AIClient) initIntegrations() { oc.eventRegistry = &eventIntegrationRegistry{} oc.purgeRegistry = &purgeIntegrationRegistry{} oc.approvalRegistry = &toolApprovalIntegrationRegistry{} - oc.integrationModules = make(map[string]any) + oc.integrationModules = make(map[string]integrationruntime.ModuleHooks) oc.integrationOrder = nil host := newRuntimeIntegrationHost(oc) @@ -258,7 +255,7 @@ func (oc *AIClient) initIntegrations() { oc.toolRegistry.register(toolIntegration) } if commandIntegration, ok := module.(integrationruntime.CommandIntegration); ok { - defs := commandIntegration.CommandDefinitions(context.Background(), oc.commandScope(nil, nil, nil)) + defs := commandIntegration.CommandDefinitions(context.Background(), oc.commandScope(nil, nil)) oc.commandRegistry.register(commandIntegration, defs) } if eventIntegration, ok := module.(integrationruntime.EventIntegration); ok { @@ -286,7 +283,7 @@ func (oc *AIClient) integratedToolApprovalRequirement(toolName string, args map[ return oc.approvalRegistry.requirement(toolName, args) } -func (oc *AIClient) registerIntegrationModule(name string, module any) { +func (oc *AIClient) registerIntegrationModule(name string, module integrationruntime.ModuleHooks) { if oc == nil || module == nil { return } @@ -295,7 +292,7 @@ func (oc *AIClient) registerIntegrationModule(name string, module any) { return } if oc.integrationModules == nil { - oc.integrationModules = make(map[string]any) + oc.integrationModules = make(map[string]integrationruntime.ModuleHooks) } if _, exists := oc.integrationModules[key]; exists { return @@ -331,14 +328,14 @@ func (oc *AIClient) emitCompactionLifecycle( } } -func (oc *AIClient) integrationModule(name string) any { +func (oc *AIClient) integrationModule(name string) integrationruntime.ModuleHooks { if oc == nil || oc.integrationModules == nil { return nil } return oc.integrationModules[strings.ToLower(strings.TrimSpace(name))] } -func (oc *AIClient) eachIntegrationModule(fn func(name string, module any)) { +func (oc *AIClient) eachIntegrationModule(fn func(name string, module integrationruntime.ModuleHooks)) { if oc == nil || fn == nil || len(oc.integrationOrder) == 0 { return } @@ -355,7 +352,7 @@ func (oc *AIClient) startLifecycleIntegrations(ctx context.Context) { if oc == nil { return } - oc.eachIntegrationModule(func(name string, module any) { + oc.eachIntegrationModule(func(name string, module integrationruntime.ModuleHooks) { lifecycle, ok := module.(integrationruntime.LifecycleIntegration) if !ok { return @@ -415,7 +412,7 @@ func (oc *AIClient) stopLoginLifecycleIntegrations(bridgeID, loginID string) { if oc == nil || strings.TrimSpace(bridgeID) == "" || strings.TrimSpace(loginID) == "" { return } - oc.eachIntegrationModule(func(_ string, module any) { + oc.eachIntegrationModule(func(_ string, module integrationruntime.ModuleHooks) { loginLifecycle, ok := module.(integrationruntime.LoginLifecycleIntegration) if !ok { return @@ -472,7 +469,6 @@ func (oc *AIClient) executeIntegratedCommand( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - evt any, name string, args []string, rawArgs string, @@ -485,7 +481,7 @@ func (oc *AIClient) executeIntegratedCommand( Name: name, Args: args, RawArgs: rawArgs, - Scope: oc.commandScope(portal, meta, evt), + Scope: oc.commandScope(portal, meta), Reply: reply, }) } @@ -501,7 +497,6 @@ func (oc *AIClient) emitIntegrationSessionMutation( return } oc.eventRegistry.sessionMutation(ctx, integrationruntime.SessionMutationEvent{ - Client: oc, Portal: portal, Meta: meta, SessionKey: portal.PortalKey.String(), @@ -515,7 +510,6 @@ func (oc *AIClient) emitIntegrationFileChanged(ctx context.Context, portal *brid return } oc.eventRegistry.fileChanged(ctx, integrationruntime.FileChangedEvent{ - Client: oc, Portal: portal, Meta: meta, Path: path, @@ -539,13 +533,11 @@ func notifyIntegrationFileChanged(ctx context.Context, path string) { btc.Client.emitIntegrationFileChanged(ctx, btc.Portal, meta, path) } -func (oc *AIClient) purgeLoginIntegrations(ctx context.Context, login any, bridgeID, loginID string) { +func (oc *AIClient) purgeLoginIntegrations(ctx context.Context, _ *bridgev2.UserLogin, bridgeID, loginID string) { if oc == nil || oc.purgeRegistry == nil { return } if err := oc.purgeRegistry.purge(ctx, integrationruntime.LoginScope{ - Client: oc, - Login: login, BridgeID: bridgeID, LoginID: loginID, }); err != nil { @@ -612,8 +604,7 @@ func (c *coreToolIntegration) ExecuteTool(ctx context.Context, call integrationr } args = parsedArgs } - portal, _ := call.Scope.Portal.(*bridgev2.Portal) - result, err := c.client.executeBuiltinToolDirect(ctx, portal, call.Name, args) + result, err := c.client.executeBuiltinToolDirect(ctx, call.Scope.Portal, call.Name, args) if err != nil { return true, "", err } diff --git a/bridges/ai/integrations_test.go b/bridges/ai/integrations_test.go index b970a6f5..5be5fea4 100644 --- a/bridges/ai/integrations_test.go +++ b/bridges/ai/integrations_test.go @@ -50,6 +50,8 @@ type fakeLifecycleIntegration struct { name string } +func (f *fakeLifecycleIntegration) Name() string { return f.name } + func (f *fakeLifecycleIntegration) Start(_ context.Context) error { f.startCount++ return nil diff --git a/bridges/ai/media_understanding_runner.go b/bridges/ai/media_understanding_runner.go index f4d4ca4f..69253d89 100644 --- a/bridges/ai/media_understanding_runner.go +++ b/bridges/ai/media_understanding_runner.go @@ -210,11 +210,9 @@ func (oc *AIClient) applyMediaUnderstandingForAttachments( } func (oc *AIClient) resolveAutoAudioEntry(cfg *MediaUnderstandingConfig) *MediaUnderstandingModelConfig { - headers := map[string]string{} - if cfg != nil && cfg.Headers != nil { - for key, value := range cfg.Headers { - headers[key] = value - } + var headers map[string]string + if cfg != nil { + headers = cfg.Headers } candidates := []struct { @@ -332,11 +330,9 @@ func (oc *AIClient) resolveKeyMediaEntry( } func (oc *AIClient) hasMediaProviderAuth(providerID string, cfg *MediaUnderstandingConfig) bool { - headers := map[string]string{} - if cfg != nil && cfg.Headers != nil { - for key, value := range cfg.Headers { - headers[key] = value - } + var headers map[string]string + if cfg != nil { + headers = cfg.Headers } if hasProviderAuthHeader(providerID, headers) { return true diff --git a/bridges/ai/metadata.go b/bridges/ai/metadata.go index 8b54aa4a..1afe73ed 100644 --- a/bridges/ai/metadata.go +++ b/bridges/ai/metadata.go @@ -311,6 +311,32 @@ func (m *PortalMetadata) SetModuleMeta(key string, value any) { m.ModuleMeta[key] = value } +func (m *PortalMetadata) ModuleMetaValue(key string) any { + if m == nil || m.ModuleMeta == nil { + return nil + } + return m.ModuleMeta[key] +} + +func (m *PortalMetadata) SetModuleMetaValue(key string, value any) { + m.SetModuleMeta(key, value) +} + +func (m *PortalMetadata) AgentID() string { + return resolveAgentID(m) +} + +func (m *PortalMetadata) CompactionCounter() int { + if m == nil { + return 0 + } + return m.CompactionCount +} + +func (m *PortalMetadata) InternalRoom() bool { + return isModuleInternalRoom(m) +} + func cloneUserLoginMetadata(src *UserLoginMetadata) (*UserLoginMetadata, error) { if src == nil { return &UserLoginMetadata{}, nil diff --git a/bridges/ai/response_finalization_test.go b/bridges/ai/response_finalization_test.go index 83612a79..5f61bdaa 100644 --- a/bridges/ai/response_finalization_test.go +++ b/bridges/ai/response_finalization_test.go @@ -14,7 +14,7 @@ import ( ) func testStreamingState(turnID string) *streamingState { - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) turn := conv.StartTurn(context.Background(), nil, nil) turn.SetID(turnID) return &streamingState{ diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index f742c99d..a140a893 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -35,12 +35,13 @@ func (oc *AIClient) responseWithRetry( currentPrompt := ClonePromptContext(prompt) preflightFlushAttempted := false overflowCompactionAttempts := 0 + cachedTokenEstimate := -1 var lastCLE *ContextLengthError for attempt := range maxRetryAttempts { if !preflightFlushAttempted { preflightFlushAttempted = true - oc.runCompactionPreflightFlushHook(ctx, portal, meta, currentPrompt, attempt+1) + cachedTokenEstimate = oc.runCompactionPreflightFlushHook(ctx, portal, meta, currentPrompt, attempt+1) } success, cle, err := responseFn(ctx, evt, portal, meta, currentPrompt) @@ -69,7 +70,11 @@ func (oc *AIClient) responseWithRetry( if meta != nil { modelID = oc.effectiveModel(meta) } - tokensBefore := estimatePromptContextTokensForModel(currentPrompt, modelID) + tokensBefore := cachedTokenEstimate + if tokensBefore < 0 { + tokensBefore = estimatePromptContextTokensForModel(currentPrompt, modelID) + } + cachedTokenEstimate = -1 // invalidate after use if overflowCompactionAttempts < maxRetryAttempts { overflowCompactionAttempts++ @@ -77,7 +82,6 @@ func (oc *AIClient) responseWithRetry( // Emit compaction start event. oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecycleStart, @@ -117,7 +121,6 @@ func (oc *AIClient) responseWithRetry( WillRetry: true, }) oc.emitCompactionLifecyclePhases(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Attempt: attempt + 1, @@ -159,7 +162,6 @@ func (oc *AIClient) responseWithRetry( WillRetry: true, }) oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecycleEnd, @@ -189,7 +191,6 @@ func (oc *AIClient) responseWithRetry( Error: "compaction did not reduce context sufficiently", }) oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecycleFail, @@ -235,9 +236,9 @@ func (oc *AIClient) runCompactionPreflightFlushHook( meta *PortalMetadata, prompt PromptContext, attempt int, -) { +) int { if oc == nil || meta == nil { - return + return -1 } contextWindow := oc.getModelContextWindow(meta) if contextWindow <= 0 { @@ -247,7 +248,6 @@ func (oc *AIClient) runCompactionPreflightFlushHook( promptTokens := estimatePromptContextTokensForModel(prompt, modelID) projectedTokens := projectedCompactionFlushTokens(meta, promptTokens) oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecyclePreFlush, @@ -262,6 +262,7 @@ func (oc *AIClient) runCompactionPreflightFlushHook( RequestedTokens: projectedTokens, ModelMaxTokens: contextWindow, }, attempt) + return promptTokens } func projectedCompactionFlushTokens(meta *PortalMetadata, promptTokens int) int { @@ -337,7 +338,6 @@ func (oc *AIClient) runCompactionFlushHook( return } hook.OnContextOverflow(ctx, integrationruntime.ContextOverflowCall{ - Client: oc, Portal: portal, Meta: meta, Prompt: promptContextToChatCompletionMessages(prompt, false), diff --git a/bridges/ai/streaming_error_handling_test.go b/bridges/ai/streaming_error_handling_test.go index b981929f..6ae8fbfa 100644 --- a/bridges/ai/streaming_error_handling_test.go +++ b/bridges/ai/streaming_error_handling_test.go @@ -14,7 +14,7 @@ import ( func newTestStreamingStateWithTurn() *streamingState { state := newStreamingState(context.Background(), nil, "") - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) state.turn = conv.StartTurn(context.Background(), nil, nil) return state } diff --git a/bridges/ai/streaming_init.go b/bridges/ai/streaming_init.go index 28862cb8..1833ec07 100644 --- a/bridges/ai/streaming_init.go +++ b/bridges/ai/streaming_init.go @@ -23,7 +23,7 @@ func (oc *AIClient) createStreamingTurn( sourceEventID id.EventID, senderID string, ) *bridgesdk.Turn { - var sdkConfig *bridgesdk.Config + var sdkConfig *bridgesdk.Config[*AIClient, *Config] if oc.connector != nil { sdkConfig = oc.connector.sdkConfig } diff --git a/bridges/ai/streaming_output_items_test.go b/bridges/ai/streaming_output_items_test.go index 0ebac79b..80d6c293 100644 --- a/bridges/ai/streaming_output_items_test.go +++ b/bridges/ai/streaming_output_items_test.go @@ -60,7 +60,7 @@ func TestDeriveToolDescriptorForOutputItem_FunctionCallParsesArgumentsJSON(t *te func TestUpsertActiveToolFromDescriptor_RecreatesNilMapEntry(t *testing.T) { oc := &AIClient{} state := newStreamingState(context.Background(), nil, "") - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) state.turn = conv.StartTurn(context.Background(), nil, nil) activeTools := newStreamToolRegistry() activeTools.byKey[streamToolItemKey("item_123")] = nil diff --git a/bridges/ai/streaming_ui_tools_test.go b/bridges/ai/streaming_ui_tools_test.go index 6d910464..0c59617f 100644 --- a/bridges/ai/streaming_ui_tools_test.go +++ b/bridges/ai/streaming_ui_tools_test.go @@ -46,7 +46,7 @@ func TestRequestTurnApprovalWithoutApprovalFlowReturnsHandle(t *testing.T) { func TestStartStreamingMCPApprovalAutoApprovedEmitsApprovalRequest(t *testing.T) { oc := newTestAIClient("@owner:example.com") state := newStreamingState(context.Background(), nil, "") - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) state.turn = conv.StartTurn(context.Background(), nil, nil) handle, err := oc.startStreamingMCPApproval(context.Background(), nil, state, ToolApprovalParams{ diff --git a/bridges/ai/system_prompts.go b/bridges/ai/system_prompts.go index 04d1a41c..54c63ce3 100644 --- a/bridges/ai/system_prompts.go +++ b/bridges/ai/system_prompts.go @@ -6,6 +6,7 @@ import ( "maunium.net/go/mautrix/bridgev2" + integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" runtimeparse "github.com/beeper/agentremote/pkg/runtime" ) @@ -32,10 +33,6 @@ func buildGroupIntro(roomName string, activation string) string { return strings.Join(lines, " ") + " Address the specific sender noted in the message context." } -func buildVerboseSystemHint(_ *PortalMetadata) string { - return "" -} - func buildSessionIdentityHint(portal *bridgev2.Portal, _ *PortalMetadata) string { if portal == nil { return "" @@ -54,10 +51,6 @@ func buildSessionIdentityHint(portal *bridgev2.Portal, _ *PortalMetadata) string return "sessionKey: " + session } -type memoryPromptAugmentor interface { - PromptContextText(ctx context.Context, portal any, meta any) string -} - func (oc *AIClient) buildAdditionalSystemPromptText( ctx context.Context, portal *bridgev2.Portal, @@ -109,12 +102,6 @@ func (oc *AIClient) buildAdditionalSystemPromptCoreText( } } - if meta != nil { - if verboseHint := buildVerboseSystemHint(meta); verboseHint != "" { - out = append(out, verboseHint) - } - } - if accountHint := oc.buildDesktopAccountHintPrompt(ctx); accountHint != "" { out = append(out, accountHint) } @@ -135,9 +122,12 @@ func (oc *AIClient) buildMemoryPromptContextText( return "" } module := oc.integrationModules["memory"] - augmentor, ok := module.(memoryPromptAugmentor) + augmentor, ok := module.(integrationruntime.PromptContextIntegration) if !ok || augmentor == nil { return "" } - return strings.TrimSpace(augmentor.PromptContextText(ctx, portal, meta)) + return strings.TrimSpace(augmentor.PromptContextText(ctx, integrationruntime.PromptScope{ + Portal: portal, + Meta: meta, + })) } diff --git a/bridges/ai/text_files.go b/bridges/ai/text_files.go index c463e72a..3690b1a7 100644 --- a/bridges/ai/text_files.go +++ b/bridges/ai/text_files.go @@ -233,7 +233,7 @@ func (oc *AIClient) downloadPDFFile(ctx context.Context, mediaURL string, encryp return trimmed, truncated, nil } -func buildTextFileMessage(caption string, hasUserCaption bool, filename string, mimeType string, content string, _ bool) string { +func buildTextFileMessage(caption string, hasUserCaption bool, filename string, mimeType string, content string, truncated bool) string { if !hasUserCaption { caption = "" } @@ -247,10 +247,15 @@ func buildTextFileMessage(caption string, hasUserCaption bool, filename string, mimeType = "text/plain" } + truncAttr := "" + if truncated { + truncAttr = " truncated=\"true\"" + } block := fmt.Sprintf( - "\n%s\n", + "\n%s\n", xmlEscapeAttr(filename), xmlEscapeAttr(mimeType), + truncAttr, escapeFileBlockContent(content), ) diff --git a/bridges/codex/connector.go b/bridges/codex/connector.go index ff52d318..1dfc66f3 100644 --- a/bridges/codex/connector.go +++ b/bridges/codex/connector.go @@ -29,7 +29,7 @@ type CodexConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*CodexClient, *Config] db *dbutil.Database clientsMu sync.Mutex diff --git a/bridges/codex/constructors.go b/bridges/codex/constructors.go index d73d89ae..8e9cf395 100644 --- a/bridges/codex/constructors.go +++ b/bridges/codex/constructors.go @@ -33,7 +33,7 @@ func NewConnector() *CodexConnector { Description: "Provide externally managed ChatGPT id/access tokens.", }, } - cc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + cc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*CodexClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "codex", Description: "A Matrix↔Codex bridge built on mautrix-go bridgev2.", ProtocolID: "ai-codex", @@ -76,10 +76,10 @@ func NewConnector() *CodexConnector { ExampleConfig: exampleNetworkConfig, ConfigData: &cc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, AcceptLogin: func(login *bridgev2.UserLogin) (bool, string) { return bridgesdk.AcceptProviderLogin(login, ProviderCodex, "This bridge only supports Codex logins.", cc.codexEnabled, "Codex integration is disabled in the configuration.", func(login *bridgev2.UserLogin) string { return loginMetadata(login).Provider diff --git a/bridges/codex/stream_mapping_test.go b/bridges/codex/stream_mapping_test.go index 74ca2150..ed1c3393 100644 --- a/bridges/codex/stream_mapping_test.go +++ b/bridges/codex/stream_mapping_test.go @@ -23,7 +23,7 @@ func attachTestTurn(state *streamingState, portal *bridgev2.Portal) { if state == nil { return } - conv := bridgesdk.NewConversation(context.Background(), nil, portal, bridgev2.EventSender{}, &bridgesdk.Config{}, nil) + conv := bridgesdk.NewConversation(context.Background(), nil, portal, bridgev2.EventSender{}, &bridgesdk.Config[*CodexClient, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), nil, nil) turn.SetID(state.turnID) state.turn = turn diff --git a/bridges/dummybridge/bridge.go b/bridges/dummybridge/bridge.go index c59c0e1d..e9130798 100644 --- a/bridges/dummybridge/bridge.go +++ b/bridges/dummybridge/bridge.go @@ -31,15 +31,14 @@ func (dc *DummyBridgeConnector) loggerForLogin(login *bridgev2.UserLogin) zerolo return login.Log.With().Str("component", "dummybridge").Logger() } -func sessionFromAny(session any) (*dummySession, error) { - dummy, ok := session.(*dummySession) - if !ok || dummy == nil || dummy.login == nil { +func requireSession(session *dummySession) (*dummySession, error) { + if session == nil || session.login == nil { return nil, errors.New("dummybridge session is unavailable") } - return dummy, nil + return session, nil } -func (dc *DummyBridgeConnector) onConnect(ctx context.Context, info *bridgesdk.LoginInfo) (any, error) { +func (dc *DummyBridgeConnector) onConnect(ctx context.Context, info *bridgesdk.LoginInfo) (*dummySession, error) { if info == nil || info.Login == nil { return nil, errors.New("missing login info") } @@ -58,12 +57,12 @@ func (dc *DummyBridgeConnector) onConnect(ctx context.Context, info *bridgesdk.L }, nil } -func (dc *DummyBridgeConnector) onDisconnect(session any) { - _, _ = sessionFromAny(session) +func (dc *DummyBridgeConnector) onDisconnect(session *dummySession) { + _, _ = requireSession(session) } -func (dc *DummyBridgeConnector) getContactList(ctx context.Context, session any) ([]*bridgev2.ResolveIdentifierResponse, error) { - dummy, err := sessionFromAny(session) +func (dc *DummyBridgeConnector) getContactList(ctx context.Context, session *dummySession) ([]*bridgev2.ResolveIdentifierResponse, error) { + dummy, err := requireSession(session) if err != nil { return nil, err } @@ -74,8 +73,8 @@ func (dc *DummyBridgeConnector) getContactList(ctx context.Context, session any) return []*bridgev2.ResolveIdentifierResponse{resp}, nil } -func (dc *DummyBridgeConnector) searchUsers(ctx context.Context, session any, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { - dummy, err := sessionFromAny(session) +func (dc *DummyBridgeConnector) searchUsers(ctx context.Context, session *dummySession, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { + dummy, err := requireSession(session) if err != nil { return nil, err } @@ -99,8 +98,8 @@ func (dc *DummyBridgeConnector) searchUsers(ctx context.Context, session any, qu return nil, nil } -func (dc *DummyBridgeConnector) resolveIdentifier(ctx context.Context, session any, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { - dummy, err := sessionFromAny(session) +func (dc *DummyBridgeConnector) resolveIdentifier(ctx context.Context, session *dummySession, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + dummy, err := requireSession(session) if err != nil { return nil, err } diff --git a/bridges/dummybridge/connector.go b/bridges/dummybridge/connector.go index d32adad9..e90e139b 100644 --- a/bridges/dummybridge/connector.go +++ b/bridges/dummybridge/connector.go @@ -21,7 +21,7 @@ type DummyBridgeConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*dummySession, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI @@ -31,7 +31,7 @@ type DummyBridgeConnector struct { func NewConnector() *DummyBridgeConnector { dc := &DummyBridgeConnector{} - dc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + dc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*dummySession, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "dummybridge", Description: "A synthetic Matrix↔DummyBridge demo bridge built on the AgentRemote SDK.", ProtocolID: "ai-dummybridge", @@ -57,10 +57,10 @@ func NewConnector() *DummyBridgeConnector { ExampleConfig: exampleNetworkConfig, ConfigData: &dc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, AcceptLogin: func(login *bridgev2.UserLogin) (bool, string) { return bridgesdk.AcceptProviderLogin(login, ProviderDummyBridge, "This bridge only supports DummyBridge logins.", dc.enabled, "DummyBridge integration is disabled in the configuration.", func(login *bridgev2.UserLogin) string { return loginMetadata(login).Provider diff --git a/bridges/dummybridge/runtime.go b/bridges/dummybridge/runtime.go index 13021b2e..c01c43e3 100644 --- a/bridges/dummybridge/runtime.go +++ b/bridges/dummybridge/runtime.go @@ -247,7 +247,7 @@ const ( randomActionTransient randomActionKind = "data_transient" ) -func (dc *DummyBridgeConnector) onMessage(session any, conv *bridgesdk.Conversation, msg *bridgesdk.Message, turn *bridgesdk.Turn) error { +func (dc *DummyBridgeConnector) onMessage(session *dummySession, conv *bridgesdk.Conversation, msg *bridgesdk.Message, turn *bridgesdk.Turn) error { if conv == nil || turn == nil || msg == nil { return nil } @@ -265,7 +265,7 @@ func (dc *DummyBridgeConnector) onMessage(session any, conv *bridgesdk.Conversat if cmd.Name == "help" { return conv.SendNotice(turn.Context(), helpText()) } - dummy, err := sessionFromAny(session) + dummy, err := requireSession(session) if err != nil { return err } diff --git a/bridges/dummybridge/runtime_test.go b/bridges/dummybridge/runtime_test.go index 53b5b1fc..60290bd4 100644 --- a/bridges/dummybridge/runtime_test.go +++ b/bridges/dummybridge/runtime_test.go @@ -56,7 +56,7 @@ func (r *advancingRuntime) sleepFn(_ context.Context, delay time.Duration) error } func newTestTurn() *bridgesdk.Turn { - cfg := &bridgesdk.Config{ + cfg := &bridgesdk.Config[*dummySession, *struct{}]{ ProviderIdentity: bridgesdk.ProviderIdentity{IDPrefix: "dummybridge", StatusNetwork: "dummybridge"}, } // These tests only exercise turn-local streaming behavior. Login/portal are diff --git a/bridges/openclaw/connector.go b/bridges/openclaw/connector.go index a497397d..d6e27867 100644 --- a/bridges/openclaw/connector.go +++ b/bridges/openclaw/connector.go @@ -23,7 +23,7 @@ type OpenClawConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*OpenClawClient, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI @@ -41,7 +41,7 @@ type openClawLoginPrefill struct { func NewConnector() *OpenClawConnector { oc := &OpenClawConnector{} - oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*OpenClawClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "openclaw", Description: "A Matrix↔OpenClaw bridge built on mautrix-go bridgev2.", ProtocolID: "ai-openclaw", @@ -75,10 +75,10 @@ func NewConnector() *OpenClawConnector { ExampleConfig: exampleNetworkConfig, ConfigData: &oc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, NetworkCapabilities: func() *bridgev2.NetworkGeneralCapabilities { caps := agentremote.DefaultNetworkCapabilities() caps.DisappearingMessages = false diff --git a/bridges/openclaw/stream_test.go b/bridges/openclaw/stream_test.go index c03ccfcb..dccca886 100644 --- a/bridges/openclaw/stream_test.go +++ b/bridges/openclaw/stream_test.go @@ -64,7 +64,7 @@ func (testMatrixAPI) GetEvent(context.Context, id.RoomID, id.EventID) (*event.Ev } func newOpenClawTestTurn(turnID string) *bridgesdk.Turn { - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &bridgesdk.Config{}, nil) + conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &bridgesdk.Config[*OpenClawClient, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), nil, nil) turn.SetID(turnID) return turn @@ -301,7 +301,7 @@ func TestEmitStreamPartSerializesTurnCreation(t *testing.T) { oc := newOpenClawTestClient(map[string]*openClawStreamState{}) oc.UserLogin = &bridgev2.UserLogin{Bridge: &bridgev2.Bridge{Bot: testMatrixAPI{}}} oc.connector = &OpenClawConnector{} - oc.connector.sdkConfig = &bridgesdk.Config{} + oc.connector.sdkConfig = &bridgesdk.Config[*OpenClawClient, *Config]{} original := openClawNewSDKStreamTurn defer func() { openClawNewSDKStreamTurn = original }() diff --git a/bridges/opencode/connector.go b/bridges/opencode/connector.go index 6b44af91..1221566d 100644 --- a/bridges/opencode/connector.go +++ b/bridges/opencode/connector.go @@ -22,7 +22,7 @@ type OpenCodeConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*OpenCodeClient, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI @@ -42,7 +42,7 @@ func NewConnector() *OpenCodeConnector { Description: "Let the bridge spawn and manage OpenCode processes for you.", }, } - oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*OpenCodeClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "opencode", Description: "A Matrix↔OpenCode bridge built on mautrix-go bridgev2.", ProtocolID: "ai-opencode", @@ -50,7 +50,7 @@ func NewConnector() *OpenCodeConnector { ProviderIdentity: bridgesdk.ProviderIdentity{IDPrefix: "opencode", LogKey: "opencode_msg_id", StatusNetwork: "opencode"}, ClientCacheMu: &oc.clientsMu, ClientCache: &oc.clients, - GetCapabilities: func(session any, _ *bridgesdk.Conversation) *bridgesdk.RoomFeatures { + GetCapabilities: func(_ *OpenCodeClient, _ *bridgesdk.Conversation) *bridgesdk.RoomFeatures { return &bridgesdk.RoomFeatures{Custom: openCodeMatrixRoomFeatures()} }, InitConnector: func(bridge *bridgev2.Bridge) { @@ -72,10 +72,10 @@ func NewConnector() *OpenCodeConnector { ExampleConfig: exampleNetworkConfig, ConfigData: &oc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, AcceptLogin: func(login *bridgev2.UserLogin) (bool, string) { return bridgesdk.AcceptProviderLogin(login, ProviderOpenCode, "This bridge only supports OpenCode logins.", oc.openCodeEnabled, "OpenCode integration is disabled in the configuration.", func(login *bridgev2.UserLogin) string { return loginMetadata(login).Provider diff --git a/pkg/integrations/cron/integration.go b/pkg/integrations/cron/integration.go index 74081b73..f9a028a6 100644 --- a/pkg/integrations/cron/integration.go +++ b/pkg/integrations/cron/integration.go @@ -228,7 +228,7 @@ func (i *Integration) buildToolExecDeps(ctx context.Context, scope iruntime.Tool ResolveCreateContext: func() ToolCreateContext { agentID := i.host.DefaultAgentID() if scope.Meta != nil { - if resolved := strings.TrimSpace(i.host.AgentIDFromMeta(scope.Meta)); resolved != "" { + if resolved := strings.TrimSpace(scope.Meta.AgentID()); resolved != "" { agentID = resolved } } @@ -238,7 +238,7 @@ func (i *Integration) buildToolExecDeps(ctx context.Context, scope iruntime.Tool } sourceInternal := false if scope.Meta != nil { - sourceInternal = i.host.IsInternalRoom(scope.Meta) + sourceInternal = scope.Meta.InternalRoom() } return ToolCreateContext{AgentID: agentID, SourceInternal: sourceInternal, SourceRoomID: roomID} }, @@ -281,7 +281,6 @@ func (i *Integration) buildToolExecDeps(ctx context.Context, scope iruntime.Tool func commandScopeToToolScope(scope iruntime.CommandScope) iruntime.ToolScope { return iruntime.ToolScope{ - Client: scope.Client, Portal: scope.Portal, Meta: scope.Meta, } diff --git a/pkg/integrations/memory/integration.go b/pkg/integrations/memory/integration.go index fddfdf7f..e209a835 100644 --- a/pkg/integrations/memory/integration.go +++ b/pkg/integrations/memory/integration.go @@ -9,6 +9,7 @@ import ( "github.com/openai/openai-go/v3" "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/bridgev2" "github.com/beeper/agentremote/pkg/agents" iruntime "github.com/beeper/agentremote/pkg/integrations/runtime" @@ -69,7 +70,7 @@ func (i *Integration) ToolAvailability(_ context.Context, scope iruntime.ToolSco return false, false, iruntime.SourceGlobalDefault, "" } if scope.Meta != nil { - agentID := i.host.AgentIDFromMeta(scope.Meta) + agentID := scope.Meta.AgentID() _, errMsg := i.getManager(agentID) if errMsg != "" { return true, false, iruntime.SourceProviderLimit, errMsg @@ -78,8 +79,8 @@ func (i *Integration) ToolAvailability(_ context.Context, scope iruntime.ToolSco return true, true, iruntime.SourceGlobalDefault, "" } -func (i *Integration) PromptContextText(ctx context.Context, portal any, meta any) string { - return BuildPromptContextText(ctx, portal, meta, PromptContextDeps{ +func (i *Integration) PromptContextText(ctx context.Context, scope iruntime.PromptScope) string { + return BuildPromptContextText(ctx, scope.Portal, scope.Meta, PromptContextDeps{ ShouldInjectContext: i.shouldInjectMemoryPromptContext, ShouldBootstrap: i.shouldBootstrapMemoryPromptContext, ResolveBootstrapPaths: i.resolveMemoryBootstrapPaths, @@ -134,16 +135,16 @@ func (i *Integration) OnCompactionLifecycle(ctx context.Context, evt iruntime.Co } switch evt.Phase { case iruntime.CompactionLifecycleStart: - i.host.SetModuleMeta(evt.Meta, "compaction_in_flight", true) + evt.Meta.SetModuleMetaValue("compaction_in_flight", true) case iruntime.CompactionLifecycleEnd: - i.host.SetModuleMeta(evt.Meta, "compaction_in_flight", false) - i.host.SetModuleMeta(evt.Meta, "last_compaction_at", time.Now().UnixMilli()) - i.host.SetModuleMeta(evt.Meta, "last_compaction_dropped_count", evt.DroppedCount) + evt.Meta.SetModuleMetaValue("compaction_in_flight", false) + evt.Meta.SetModuleMetaValue("last_compaction_at", time.Now().UnixMilli()) + evt.Meta.SetModuleMetaValue("last_compaction_dropped_count", evt.DroppedCount) case iruntime.CompactionLifecycleFail: - i.host.SetModuleMeta(evt.Meta, "compaction_in_flight", false) - i.host.SetModuleMeta(evt.Meta, "last_compaction_error", strings.TrimSpace(evt.Error)) + evt.Meta.SetModuleMetaValue("compaction_in_flight", false) + evt.Meta.SetModuleMetaValue("last_compaction_error", strings.TrimSpace(evt.Error)) case iruntime.CompactionLifecycleRefresh: - i.host.SetModuleMeta(evt.Meta, "last_compaction_refresh_at", time.Now().UnixMilli()) + evt.Meta.SetModuleMetaValue("last_compaction_refresh_at", time.Now().UnixMilli()) } if evt.Portal == nil { return @@ -200,11 +201,6 @@ func (i *Integration) buildCommandExecDeps() CommandExecDeps { } } -func asOverflowCall(call any) (iruntime.ContextOverflowCall, bool) { - oc, ok := call.(iruntime.ContextOverflowCall) - return oc, ok -} - func toInt64(v any) int64 { switch n := v.(type) { case int64: @@ -224,50 +220,36 @@ func (i *Integration) buildOverflowDeps() OverflowDeps { TrimPrompt: func(prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { return i.host.SmartTruncatePrompt(prompt, 0.5) }, - ContextWindow: func(call any) int { - oc, ok := asOverflowCall(call) - if !ok { - return 128000 - } - return i.host.ContextWindow(oc.Meta) + ContextWindow: func(call iruntime.ContextOverflowCall) int { + return i.host.ContextWindow(call.Meta) }, ReserveTokens: func() int { return i.host.CompactorReserveTokens() }, - EffectiveModel: func(call any) string { - oc, ok := asOverflowCall(call) - if !ok { - return "" - } - return i.host.EffectiveModel(oc.Meta) + EffectiveModel: func(call iruntime.ContextOverflowCall) string { + return i.host.EffectiveModel(call.Meta) }, EstimateTokens: func(prompt []openai.ChatCompletionMessageParamUnion, model string) int { return i.host.EstimateTokens(prompt, model) }, - AlreadyFlushed: func(call any) bool { - oc, ok := asOverflowCall(call) - if !ok { - return false - } - flushAtMs := toInt64(i.host.GetModuleMeta(oc.Meta, "overflow_flush_at")) + AlreadyFlushed: func(call iruntime.ContextOverflowCall) bool { + flushAtMs := toInt64(call.Meta.ModuleMetaValue("overflow_flush_at")) if flushAtMs == 0 { return false } - flushCC := toInt64(i.host.GetModuleMeta(oc.Meta, "overflow_flush_compaction_count")) - return int(flushCC) == i.host.CompactionCount(oc.Meta) + flushCC := toInt64(call.Meta.ModuleMetaValue("overflow_flush_compaction_count")) + return int(flushCC) == call.Meta.CompactionCounter() }, - MarkFlushed: func(ctx context.Context, call any) { - oc, _ := asOverflowCall(call) - if oc.Portal == nil || oc.Meta == nil { + MarkFlushed: func(ctx context.Context, call iruntime.ContextOverflowCall) { + if call.Portal == nil || call.Meta == nil { return } - i.host.SetModuleMeta(oc.Meta, "overflow_flush_at", time.Now().UnixMilli()) - i.host.SetModuleMeta(oc.Meta, "overflow_flush_compaction_count", i.host.CompactionCount(oc.Meta)) - _ = i.host.SavePortal(ctx, oc.Portal, "overflow flush") + call.Meta.SetModuleMetaValue("overflow_flush_at", time.Now().UnixMilli()) + call.Meta.SetModuleMetaValue("overflow_flush_compaction_count", call.Meta.CompactionCounter()) + _ = i.host.SavePortal(ctx, call.Portal, "overflow flush") }, - RunFlushToolLoop: func(ctx context.Context, call any, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) { - oc, _ := asOverflowCall(call) - return i.runFlushToolLoop(ctx, oc.Portal, oc.Meta, model, prompt) + RunFlushToolLoop: func(ctx context.Context, call iruntime.ContextOverflowCall, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) { + return i.runFlushToolLoop(ctx, call.Portal, call.Meta, model, prompt) }, OnError: func(_ context.Context, err error) { i.host.Logger().Warn("overflow flush failed", map[string]any{"error": err.Error()}) @@ -275,7 +257,7 @@ func (i *Integration) buildOverflowDeps() OverflowDeps { } } -func (i *Integration) shouldInjectMemoryPromptContext(_ any, _ any) bool { +func (i *Integration) shouldInjectMemoryPromptContext(_ *bridgev2.Portal, _ iruntime.Meta) bool { if cfg := i.host.ModuleConfig(moduleName); cfg != nil { inject, _ := cfg["inject_context"].(bool) return inject @@ -283,15 +265,15 @@ func (i *Integration) shouldInjectMemoryPromptContext(_ any, _ any) bool { return false } -func (i *Integration) shouldBootstrapMemoryPromptContext(_ any, meta any) bool { - raw := i.host.GetModuleMeta(meta, "memory_bootstrap_at") +func (i *Integration) shouldBootstrapMemoryPromptContext(_ *bridgev2.Portal, meta iruntime.Meta) bool { + raw := meta.ModuleMetaValue("memory_bootstrap_at") if raw == nil { return true } return toInt64(raw) == 0 } -func (i *Integration) resolveMemoryBootstrapPaths(_ any, _ any) []string { +func (i *Integration) resolveMemoryBootstrapPaths(_ *bridgev2.Portal, _ iruntime.Meta) []string { _, loc := i.host.UserTimezone() if loc == nil { loc = time.UTC @@ -305,18 +287,18 @@ func (i *Integration) resolveMemoryBootstrapPaths(_ any, _ any) []string { } } -func (i *Integration) markMemoryPromptBootstrapped(ctx context.Context, portal any, meta any) { +func (i *Integration) markMemoryPromptBootstrapped(ctx context.Context, portal *bridgev2.Portal, meta iruntime.Meta) { if portal == nil || meta == nil { return } - i.host.SetModuleMeta(meta, "memory_bootstrap_at", time.Now().UnixMilli()) + meta.SetModuleMetaValue("memory_bootstrap_at", time.Now().UnixMilli()) _ = i.host.SavePortal(ctx, portal, "memory bootstrap") } -func (i *Integration) readMemoryPromptSection(ctx context.Context, meta any, path string) string { +func (i *Integration) readMemoryPromptSection(ctx context.Context, meta iruntime.Meta, path string) string { agentID := "" if meta != nil { - agentID = i.host.AgentIDFromMeta(meta) + agentID = meta.AgentID() } content, filePath, found, err := i.host.ReadTextFile(ctx, agentID, path) if err != nil || !found { @@ -354,8 +336,8 @@ func (i *Integration) getManager(agentID string) (*MemorySearchManager, string) func (i *Integration) runFlushToolLoop( ctx context.Context, - portal any, - meta any, + portal *bridgev2.Portal, + meta iruntime.Meta, model string, messages []openai.ChatCompletionMessageParamUnion, ) (bool, error) { @@ -458,26 +440,21 @@ func (i *Integration) writeMemoryCommandFile( ) (string, error) { agentID := "" if scope.Meta != nil { - agentID = i.host.AgentIDFromMeta(scope.Meta) + agentID = scope.Meta.AgentID() } return i.host.WriteTextFile(ctx, scope.Portal, scope.Meta, agentID, mode, path, content, maxBytes) } -func (i *Integration) agentIDFromEventMeta(meta any) string { +func (i *Integration) agentIDFromEventMeta(meta iruntime.Meta) string { var rawAgentID string if meta != nil { - rawAgentID = i.host.AgentIDFromMeta(meta) + rawAgentID = meta.AgentID() } return i.host.ResolveAgentID(rawAgentID, i.host.DefaultAgentID()) } func (i *Integration) resolveBridgeDB() *dbutil.Database { - raw := i.host.BridgeDB() - if raw == nil { - return nil - } - db, _ := raw.(*dbutil.Database) - return db + return i.host.BridgeDB() } // splitQuotedArgs parses a raw argument string into tokens, respecting quoted segments. diff --git a/pkg/integrations/memory/manager.go b/pkg/integrations/memory/manager.go index 8f96e561..36705077 100644 --- a/pkg/integrations/memory/manager.go +++ b/pkg/integrations/memory/manager.go @@ -115,11 +115,7 @@ func GetMemorySearchManager(host iruntime.Host, agentID string) (*MemorySearchMa if host == nil { return nil, "memory search unavailable" } - rawDB := host.BridgeDB() - if rawDB == nil { - return nil, "memory search unavailable" - } - db, _ := rawDB.(*dbutil.Database) + db := host.BridgeDB() if db == nil { return nil, "memory search unavailable" } diff --git a/pkg/integrations/memory/module_exec.go b/pkg/integrations/memory/module_exec.go index 5eb2eb0f..9e7142e3 100644 --- a/pkg/integrations/memory/module_exec.go +++ b/pkg/integrations/memory/module_exec.go @@ -196,7 +196,6 @@ func ExecuteCommand(ctx context.Context, call iruntime.CommandCall, deps Command } action := strings.ToLower(strings.TrimSpace(call.Args[0])) scope := iruntime.ToolScope{ - Client: call.Scope.Client, Portal: call.Scope.Portal, Meta: call.Scope.Meta, } diff --git a/pkg/integrations/memory/overflow_exec.go b/pkg/integrations/memory/overflow_exec.go index dec9ccfb..5db54a74 100644 --- a/pkg/integrations/memory/overflow_exec.go +++ b/pkg/integrations/memory/overflow_exec.go @@ -22,19 +22,19 @@ type FlushSettings struct { type OverflowDeps struct { ResolveSettings func() *FlushSettings TrimPrompt func(prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion - ContextWindow func(call any) int + ContextWindow func(call iruntime.ContextOverflowCall) int ReserveTokens func() int - EffectiveModel func(call any) string + EffectiveModel func(call iruntime.ContextOverflowCall) string EstimateTokens func(prompt []openai.ChatCompletionMessageParamUnion, model string) int - AlreadyFlushed func(call any) bool - MarkFlushed func(ctx context.Context, call any) - RunFlushToolLoop func(ctx context.Context, call any, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) + AlreadyFlushed func(call iruntime.ContextOverflowCall) bool + MarkFlushed func(ctx context.Context, call iruntime.ContextOverflowCall) + RunFlushToolLoop func(ctx context.Context, call iruntime.ContextOverflowCall, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) OnError func(ctx context.Context, err error) } func HandleOverflow( ctx context.Context, - call any, + call iruntime.ContextOverflowCall, prompt []openai.ChatCompletionMessageParamUnion, deps OverflowDeps, ) { @@ -64,8 +64,8 @@ func HandleOverflow( model = deps.EffectiveModel(call) } totalTokens := 0 - if overflowCall, ok := call.(iruntime.ContextOverflowCall); ok && overflowCall.RequestedTokens > 0 { - totalTokens = overflowCall.RequestedTokens + if call.RequestedTokens > 0 { + totalTokens = call.RequestedTokens } if totalTokens <= 0 && deps.EstimateTokens != nil { totalTokens = deps.EstimateTokens(prompt, model) @@ -102,8 +102,8 @@ func HandleOverflow( func shouldRunFlush( totalTokens, contextWindow, reserveTokens int, settings *FlushSettings, - alreadyFlushed func(call any) bool, - call any, + alreadyFlushed func(call iruntime.ContextOverflowCall) bool, + call iruntime.ContextOverflowCall, ) bool { if settings == nil { return false diff --git a/pkg/integrations/memory/prompt_exec.go b/pkg/integrations/memory/prompt_exec.go index 8ad0bc63..bda520f0 100644 --- a/pkg/integrations/memory/prompt_exec.go +++ b/pkg/integrations/memory/prompt_exec.go @@ -3,20 +3,24 @@ package memory import ( "context" "strings" + + "maunium.net/go/mautrix/bridgev2" + + iruntime "github.com/beeper/agentremote/pkg/integrations/runtime" ) type PromptContextDeps struct { - ShouldInjectContext func(portal any, meta any) bool - ShouldBootstrap func(portal any, meta any) bool - ResolveBootstrapPaths func(portal any, meta any) []string - MarkBootstrapped func(ctx context.Context, portal any, meta any) - ReadSection func(ctx context.Context, meta any, path string) string + ShouldInjectContext func(portal *bridgev2.Portal, meta iruntime.Meta) bool + ShouldBootstrap func(portal *bridgev2.Portal, meta iruntime.Meta) bool + ResolveBootstrapPaths func(portal *bridgev2.Portal, meta iruntime.Meta) []string + MarkBootstrapped func(ctx context.Context, portal *bridgev2.Portal, meta iruntime.Meta) + ReadSection func(ctx context.Context, meta iruntime.Meta, path string) string } func BuildPromptContextText( ctx context.Context, - portal any, - meta any, + portal *bridgev2.Portal, + meta iruntime.Meta, deps PromptContextDeps, ) string { if deps.ShouldInjectContext == nil || !deps.ShouldInjectContext(portal, meta) { diff --git a/pkg/integrations/memory/sessions.go b/pkg/integrations/memory/sessions.go index 229dd206..02b340cd 100644 --- a/pkg/integrations/memory/sessions.go +++ b/pkg/integrations/memory/sessions.go @@ -39,11 +39,7 @@ func (m *MemorySearchManager) activeSessionPortals(ctx context.Context) (map[str if key == "" { continue } - portalKey, ok := info.PortalKey.(networkid.PortalKey) - if !ok { - continue - } - active[key] = sessionPortal{key: key, portalKey: portalKey} + active[key] = sessionPortal{key: key, portalKey: info.PortalKey} } return active, nil } diff --git a/pkg/integrations/runtime/helpers.go b/pkg/integrations/runtime/helpers.go index bcf5f733..8923537b 100644 --- a/pkg/integrations/runtime/helpers.go +++ b/pkg/integrations/runtime/helpers.go @@ -12,10 +12,7 @@ func ZerologFromHost(host Host) zerolog.Logger { if host == nil { return zerolog.Nop() } - if zl, ok := host.RawLogger().(zerolog.Logger); ok { - return zl - } - return zerolog.Nop() + return host.RawLogger() } // ModuleOrNil returns nil when the host is absent, otherwise it constructs the module. diff --git a/pkg/integrations/runtime/host_types.go b/pkg/integrations/runtime/host_types.go index 29736c4c..aa450783 100644 --- a/pkg/integrations/runtime/host_types.go +++ b/pkg/integrations/runtime/host_types.go @@ -1,6 +1,19 @@ package runtime -import "github.com/openai/openai-go/v3" +import ( + "maunium.net/go/mautrix/bridgev2/networkid" + + "github.com/openai/openai-go/v3" +) + +// Meta describes the portal metadata behavior integration modules depend on. +type Meta interface { + ModuleMetaValue(key string) any + SetModuleMetaValue(key string, value any) + AgentID() string + CompactionCounter() int + InternalRoom() bool +} // MessageSummary is a generic message summary. type MessageSummary struct { @@ -33,5 +46,5 @@ type CompletionResult struct { // SessionPortalInfo is a generic portal reference for session listing. type SessionPortalInfo struct { Key string - PortalKey any + PortalKey networkid.PortalKey } diff --git a/pkg/integrations/runtime/interfaces.go b/pkg/integrations/runtime/interfaces.go index 61d0c527..88b90570 100644 --- a/pkg/integrations/runtime/interfaces.go +++ b/pkg/integrations/runtime/interfaces.go @@ -2,6 +2,8 @@ package runtime import ( "context" + + "maunium.net/go/mautrix/bridgev2" ) // SettingSource indicates where a setting value came from. @@ -27,9 +29,8 @@ type ToolDefinition struct { // ToolScope carries integration context without coupling to connector internals. type ToolScope struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta } // ToolCall is a concrete tool execution request. @@ -54,6 +55,18 @@ type ToolApprovalIntegration interface { ToolApprovalRequirement(toolName string, args map[string]any) (handled bool, required bool, action string) } +// PromptScope carries typed prompt-building context. +type PromptScope struct { + Portal *bridgev2.Portal + Meta Meta +} + +// PromptContextIntegration contributes additional prompt context text. +type PromptContextIntegration interface { + Name() string + PromptContextText(ctx context.Context, scope PromptScope) string +} + // LifecycleIntegration is an optional capability for integrations that need runtime start/stop hooks. type LifecycleIntegration interface { Start(ctx context.Context) error diff --git a/pkg/integrations/runtime/module_hooks.go b/pkg/integrations/runtime/module_hooks.go index c4d1df31..175b7a3d 100644 --- a/pkg/integrations/runtime/module_hooks.go +++ b/pkg/integrations/runtime/module_hooks.go @@ -5,6 +5,9 @@ import ( "time" "github.com/openai/openai-go/v3" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/bridgev2" ) // ModuleHooks is the base contract every integration module implements. @@ -28,10 +31,8 @@ type CommandDefinition struct { // CommandScope carries command execution context without importing connector internals. type CommandScope struct { - Client any - Portal any - Meta any - Event any + Portal *bridgev2.Portal + Meta Meta } // CommandCall is a concrete command execution request. @@ -63,9 +64,8 @@ const ( // SessionMutationEvent is emitted when chat/session data changes. type SessionMutationEvent struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta SessionKey string Force bool Kind SessionMutationKind @@ -73,9 +73,8 @@ type SessionMutationEvent struct { // FileChangedEvent is emitted when a file write/edit/apply_patch updates workspace data. type FileChangedEvent struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta Path string } @@ -99,9 +98,8 @@ const ( // CompactionLifecycleEvent provides compaction lifecycle details to integrations. type CompactionLifecycleEvent struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta Phase CompactionLifecyclePhase Attempt int ContextWindowTokens int @@ -125,9 +123,8 @@ type CompactionLifecycleIntegration interface { // ContextOverflowCall contains context-overflow retry state. type ContextOverflowCall struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta Prompt []openai.ChatCompletionMessageParamUnion RequestedTokens int ModelMaxTokens int @@ -136,8 +133,6 @@ type ContextOverflowCall struct { // LoginScope carries per-login cleanup scope. type LoginScope struct { - Client any - Login any BridgeID string LoginID string } @@ -153,74 +148,40 @@ type LoginPurgeIntegration interface { // nested capability objects or type-asserting optional host adapters. type Host interface { Logger() Logger - RawLogger() any + RawLogger() zerolog.Logger Now() time.Time - ResolvePortalByRoomID(ctx context.Context, roomID string) any - ResolveDefaultPortal(ctx context.Context) any - ResolveLastActivePortal(ctx context.Context, agentID string) any - DispatchInternalMessage(ctx context.Context, portal any, meta any, message string, source string) error - SendAssistantMessage(ctx context.Context, portal any, body string) error - RequestNow(ctx context.Context, reason string) - ToolDefinitionByName(name string) (ToolDefinition, bool) - ExecuteBuiltinTool(ctx context.Context, scope ToolScope, name string, rawArgsJSON string) (string, error) ResolveWorkspaceDir() string - BridgeDB() any + BridgeDB() *dbutil.Database BridgeID() string LoginID() string ModuleEnabled(name string) bool ModuleConfig(name string) map[string]any AgentModuleConfig(agentID string, module string) map[string]any - GetOrCreatePortal(ctx context.Context, portalID string, receiver string, displayName string, setupMeta func(meta any)) (portal any, roomID string, err error) - SavePortal(ctx context.Context, portal any, reason string) error - PortalRoomID(portal any) string - PortalKeyString(portal any) string - - GetModuleMeta(meta any, key string) any - SetModuleMeta(meta any, key string, value any) - AgentIDFromMeta(meta any) string - CompactionCount(meta any) int - IsGroupChat(ctx context.Context, portal any) bool - IsInternalRoom(meta any) bool - PortalMeta(portal any) any - CloneMeta(portal any) any - SetMetaField(meta any, key string, value any) - - RecentMessages(ctx context.Context, portal any, count int) []MessageSummary - LastAssistantMessage(ctx context.Context, portal any) (id string, timestamp int64) - WaitForAssistantMessage(ctx context.Context, portal any, afterID string, afterTS int64) (*AssistantMessageInfo, bool) - - RunHeartbeatOnce(ctx context.Context, reason string) (status string, reasonMsg string) - ResolveHeartbeatSessionPortal(agentID string) (portal any, sessionKey string, err error) - ResolveHeartbeatSessionKey(agentID string) string - HeartbeatAckMaxChars(agentID string) int - EnqueueSystemEvent(sessionKey string, text string, agentID string) - PersistSystemEvents() - ResolveLastTarget(agentID string) (channel string, target string, ok bool) + SavePortal(ctx context.Context, portal *bridgev2.Portal, reason string) error + PortalRoomID(portal *bridgev2.Portal) string + PortalKeyString(portal *bridgev2.Portal) string + + IsGroupChat(ctx context.Context, portal *bridgev2.Portal) bool + + RecentMessages(ctx context.Context, portal *bridgev2.Portal, count int) []MessageSummary ResolveAgentID(raw string, fallbackDefault string) string - NormalizeAgentID(raw string) string - AgentExists(normalizedID string) bool DefaultAgentID() string - AgentTimeoutSeconds() int UserTimezone() (tz string, loc *time.Location) - NormalizeThinkingLevel(raw string) (string, bool) - - EffectiveModel(meta any) string - ContextWindow(meta any) int - MergeDisconnectContext(ctx context.Context) (context.Context, context.CancelFunc) - BackgroundContext(ctx context.Context) context.Context + EffectiveModel(meta Meta) string + ContextWindow(meta Meta) int - NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams any) (*CompletionResult, error) + NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams []openai.ChatCompletionToolUnionParam) (*CompletionResult, error) - IsToolEnabled(meta any, toolName string) bool + IsToolEnabled(meta Meta, toolName string) bool AllToolDefinitions() []ToolDefinition - ExecuteToolInContext(ctx context.Context, portal any, meta any, name string, argsJSON string) (string, error) - ToolsToOpenAIParams(tools []ToolDefinition) any + ExecuteToolInContext(ctx context.Context, portal *bridgev2.Portal, meta Meta, name string, argsJSON string) (string, error) + ToolsToOpenAIParams(tools []ToolDefinition) []openai.ChatCompletionToolUnionParam ReadTextFile(ctx context.Context, agentID string, path string) (content string, filePath string, found bool, err error) - WriteTextFile(ctx context.Context, portal any, meta any, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) + WriteTextFile(ctx context.Context, portal *bridgev2.Portal, meta Meta, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) SmartTruncatePrompt(prompt []openai.ChatCompletionMessageParamUnion, ratio float64) []openai.ChatCompletionMessageParamUnion EstimateTokens(prompt []openai.ChatCompletionMessageParamUnion, model string) int @@ -230,7 +191,7 @@ type Host interface { IsLoggedIn() bool SessionPortals(ctx context.Context, loginID string, agentID string) ([]SessionPortalInfo, error) - LoginDB() any + LoginDB() *dbutil.Database } // Logger is a minimal structured logger abstraction. diff --git a/sdk/client.go b/sdk/client.go index 14f9da5d..67d4fa5b 100644 --- a/sdk/client.go +++ b/sdk/client.go @@ -17,18 +17,18 @@ import ( // Compile-time interface checks. var ( - _ bridgev2.NetworkAPI = (*sdkClient)(nil) - _ bridgev2.EditHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.ReactionHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.RedactionHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.TypingHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.RoomNameHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.RoomTopicHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.BackfillingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.DeleteChatHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.IdentifierResolvingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.ContactListingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.UserSearchingNetworkAPI = (*sdkClient)(nil) + _ bridgev2.NetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.EditHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.ReactionHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.RedactionHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.TypingHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.RoomNameHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.RoomTopicHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.BackfillingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.DeleteChatHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.IdentifierResolvingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.ContactListingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.UserSearchingNetworkAPI = (*sdkClient[any, any])(nil) ) // pendingSDKApprovalData holds SDK-specific metadata for a pending tool approval. @@ -39,19 +39,19 @@ type pendingSDKApprovalData struct { ToolName string } -type sdkClient struct { +type sdkClient[SessionT SessionValue, ConfigDataT ConfigValue] struct { agentremote.ClientBase - cfg *Config + cfg *Config[SessionT, ConfigDataT] userLogin *bridgev2.UserLogin approvalFlow *agentremote.ApprovalFlow[*pendingSDKApprovalData] turnManager *TurnManager conversationState *conversationStateStore sessionMu sync.RWMutex - session any + session SessionT } -func newSDKClient(login *bridgev2.UserLogin, cfg *Config) *sdkClient { +func newSDKClient[SessionT SessionValue, ConfigDataT ConfigValue](login *bridgev2.UserLogin, cfg *Config[SessionT, ConfigDataT]) *sdkClient[SessionT, ConfigDataT] { identity := resolveProviderIdentity(cfg) senderForPortal := func(*bridgev2.Portal) bridgev2.EventSender { if cfg != nil && cfg.Agent != nil { @@ -59,7 +59,7 @@ func newSDKClient(login *bridgev2.UserLogin, cfg *Config) *sdkClient { } return bridgev2.EventSender{} } - c := &sdkClient{ + c := &sdkClient[SessionT, ConfigDataT]{ cfg: cfg, userLogin: login, conversationState: newConversationStateStore(), @@ -86,44 +86,82 @@ func newSDKClient(login *bridgev2.UserLogin, cfg *Config) *sdkClient { return c } -func (c *sdkClient) GetApprovalHandler() agentremote.ApprovalReactionHandler { +func (c *sdkClient[SessionT, ConfigDataT]) GetApprovalHandler() agentremote.ApprovalReactionHandler { return c.approvalFlow } -func (c *sdkClient) config() *Config { return c.cfg } +func (c *sdkClient[SessionT, ConfigDataT]) agent() *Agent { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.Agent +} -func (c *sdkClient) sessionValue() any { return c.getSession() } +func (c *sdkClient[SessionT, ConfigDataT]) agentCatalog() AgentCatalog { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.AgentCatalog +} -func (c *sdkClient) conversationStore() *conversationStateStore { return c.conversationState } +func (c *sdkClient[SessionT, ConfigDataT]) roomFeatures(conv *Conversation) *RoomFeatures { + if c == nil || c.cfg == nil { + return nil + } + if c.cfg.GetCapabilities != nil { + if rf := c.cfg.GetCapabilities(c.getSession(), conv); rf != nil { + return rf + } + } + return c.cfg.RoomFeatures +} + +func (c *sdkClient[SessionT, ConfigDataT]) commands() []Command { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.Commands +} + +func (c *sdkClient[SessionT, ConfigDataT]) turnConfig() *TurnConfig { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.TurnManagement +} + +func (c *sdkClient[SessionT, ConfigDataT]) conversationStore() *conversationStateStore { + return c.conversationState +} -func (c *sdkClient) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { +func (c *sdkClient[SessionT, ConfigDataT]) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { return c.approvalFlow } -func (c *sdkClient) providerIdentity() ProviderIdentity { +func (c *sdkClient[SessionT, ConfigDataT]) providerIdentity() ProviderIdentity { return resolveProviderIdentity(c.cfg) } -func (c *sdkClient) getSession() any { +func (c *sdkClient[SessionT, ConfigDataT]) getSession() SessionT { c.sessionMu.RLock() defer c.sessionMu.RUnlock() return c.session } -func (c *sdkClient) setSession(s any) { +func (c *sdkClient[SessionT, ConfigDataT]) setSession(s SessionT) { c.sessionMu.Lock() c.session = s c.sessionMu.Unlock() } // Connect implements bridgev2.NetworkAPI. -func (c *sdkClient) Connect(ctx context.Context) { - if c.config().OnConnect != nil { +func (c *sdkClient[SessionT, ConfigDataT]) Connect(ctx context.Context) { + if c.cfg != nil && c.cfg.OnConnect != nil { info := &LoginInfo{ Login: c.userLogin, UserID: string(c.userLogin.UserMXID), } - session, err := c.config().OnConnect(ctx, info) + session, err := c.cfg.OnConnect(ctx, info) if err != nil { c.userLogin.BridgeState.Send(status.BridgeState{ StateEvent: status.StateUnknownError, @@ -137,55 +175,56 @@ func (c *sdkClient) Connect(ctx context.Context) { c.userLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) } -func (c *sdkClient) Disconnect() { +func (c *sdkClient[SessionT, ConfigDataT]) Disconnect() { c.SetLoggedIn(false) if c.approvalFlow != nil { c.approvalFlow.Close() } c.CloseAllSessions() - if c.config().OnDisconnect != nil { - c.config().OnDisconnect(c.getSession()) + if c.cfg != nil && c.cfg.OnDisconnect != nil { + c.cfg.OnDisconnect(c.getSession()) } - c.setSession(nil) + var zero SessionT + c.setSession(zero) } -func (c *sdkClient) LogoutRemote(ctx context.Context) { +func (c *sdkClient[SessionT, ConfigDataT]) LogoutRemote(ctx context.Context) { c.Disconnect() } -func (c *sdkClient) IsThisUser(_ context.Context, userID networkid.UserID) bool { - if c.config().IsThisUser != nil { - return c.config().IsThisUser(string(userID)) +func (c *sdkClient[SessionT, ConfigDataT]) IsThisUser(_ context.Context, userID networkid.UserID) bool { + if c.cfg != nil && c.cfg.IsThisUser != nil { + return c.cfg.IsThisUser(string(userID)) } return false } -func (c *sdkClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { - if c.config().GetChatInfo != nil { - return c.config().GetChatInfo(c.conv(ctx, portal)) +func (c *sdkClient[SessionT, ConfigDataT]) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + if c.cfg != nil && c.cfg.GetChatInfo != nil { + return c.cfg.GetChatInfo(c.conv(ctx, portal)) } return nil, nil } -func (c *sdkClient) GetUserInfo(_ context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { - if c.config().GetUserInfo != nil { - return c.config().GetUserInfo(ghost) +func (c *sdkClient[SessionT, ConfigDataT]) GetUserInfo(_ context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { + if c.cfg != nil && c.cfg.GetUserInfo != nil { + return c.cfg.GetUserInfo(ghost) } return nil, nil } -func (c *sdkClient) GetCapabilities(_ context.Context, portal *bridgev2.Portal) *event.RoomFeatures { +func (c *sdkClient[SessionT, ConfigDataT]) GetCapabilities(_ context.Context, portal *bridgev2.Portal) *event.RoomFeatures { conv := c.conv(context.Background(), portal) return convertRoomFeatures(conv.currentRoomFeatures(context.Background())) } -func (c *sdkClient) conv(ctx context.Context, portal *bridgev2.Portal) *Conversation { +func (c *sdkClient[SessionT, ConfigDataT]) conv(ctx context.Context, portal *bridgev2.Portal) *Conversation { return newConversation(ctx, portal, c.userLogin, bridgev2.EventSender{}, c) } // HandleMatrixMessage dispatches incoming messages to the OnMessage callback. -func (c *sdkClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { - if c.config().OnMessage == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { + if c.cfg == nil || c.cfg.OnMessage == nil { return nil, nil } runCtx := c.BackgroundContext(ctx) @@ -203,7 +242,7 @@ func (c *sdkClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri roomID = c.turnManager.ResolveKey(roomID) } run := func(turnCtx context.Context) error { - return c.config().OnMessage(session, conv, sdkMsg, turn) + return c.cfg.OnMessage(session, conv, sdkMsg, turn) } go func() { var err error @@ -272,8 +311,8 @@ func convertMatrixMessage(msg *bridgev2.MatrixMessage) *Message { } // HandleMatrixEdit implements bridgev2.EditHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixEdit) error { - if c.config().OnEdit == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixEdit) error { + if c.cfg == nil || c.cfg.OnEdit == nil { return nil } me := &MessageEdit{ @@ -284,81 +323,81 @@ func (c *sdkClient) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixE me.NewText = edit.Content.Body me.NewHTML = edit.Content.FormattedBody } - return c.config().OnEdit(c.getSession(), c.conv(ctx, edit.Portal), me) + return c.cfg.OnEdit(c.getSession(), c.conv(ctx, edit.Portal), me) } // HandleMatrixMessageRemove implements bridgev2.RedactionHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { - if c.config().OnDelete == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { + if c.cfg == nil || c.cfg.OnDelete == nil { return nil } var msgID string if msg.TargetMessage != nil { msgID = string(msg.TargetMessage.ID) } - return c.config().OnDelete(c.getSession(), c.conv(ctx, msg.Portal), msgID) + return c.cfg.OnDelete(c.getSession(), c.conv(ctx, msg.Portal), msgID) } // HandleMatrixTyping implements bridgev2.TypingHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixTyping(ctx context.Context, msg *bridgev2.MatrixTyping) error { - if c.config().OnTyping != nil { - c.config().OnTyping(c.getSession(), c.conv(ctx, msg.Portal), msg.IsTyping) +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixTyping(ctx context.Context, msg *bridgev2.MatrixTyping) error { + if c.cfg != nil && c.cfg.OnTyping != nil { + c.cfg.OnTyping(c.getSession(), c.conv(ctx, msg.Portal), msg.IsTyping) } return nil } // HandleMatrixRoomName implements bridgev2.RoomNameHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixRoomName(ctx context.Context, msg *bridgev2.MatrixRoomName) (bool, error) { - if c.config().OnRoomName != nil { - return c.config().OnRoomName(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Name) +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixRoomName(ctx context.Context, msg *bridgev2.MatrixRoomName) (bool, error) { + if c.cfg != nil && c.cfg.OnRoomName != nil { + return c.cfg.OnRoomName(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Name) } return false, nil } // HandleMatrixRoomTopic implements bridgev2.RoomTopicHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixRoomTopic(ctx context.Context, msg *bridgev2.MatrixRoomTopic) (bool, error) { - if c.config().OnRoomTopic != nil { - return c.config().OnRoomTopic(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Topic) +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixRoomTopic(ctx context.Context, msg *bridgev2.MatrixRoomTopic) (bool, error) { + if c.cfg != nil && c.cfg.OnRoomTopic != nil { + return c.cfg.OnRoomTopic(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Topic) } return false, nil } // FetchMessages implements bridgev2.BackfillingNetworkAPI. -func (c *sdkClient) FetchMessages(ctx context.Context, params bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) { - if c.config().FetchMessages == nil { +func (c *sdkClient[SessionT, ConfigDataT]) FetchMessages(ctx context.Context, params bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) { + if c.cfg == nil || c.cfg.FetchMessages == nil { return nil, nil } - return c.config().FetchMessages(ctx, params) + return c.cfg.FetchMessages(ctx, params) } // HandleMatrixDeleteChat implements bridgev2.DeleteChatHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { - if c.config().DeleteChat == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { + if c.cfg == nil || c.cfg.DeleteChat == nil { return nil } - return c.config().DeleteChat(c.conv(ctx, msg.Portal)) + return c.cfg.DeleteChat(c.conv(ctx, msg.Portal)) } // ResolveIdentifier implements bridgev2.IdentifierResolvingNetworkAPI. -func (c *sdkClient) ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { - if c.config().ResolveIdentifier == nil { +func (c *sdkClient[SessionT, ConfigDataT]) ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + if c.cfg == nil || c.cfg.ResolveIdentifier == nil { return nil, nil } - return c.config().ResolveIdentifier(ctx, c.getSession(), identifier, createChat) + return c.cfg.ResolveIdentifier(ctx, c.getSession(), identifier, createChat) } // GetContactList implements bridgev2.ContactListingNetworkAPI. -func (c *sdkClient) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { - if c.config().GetContactList == nil { +func (c *sdkClient[SessionT, ConfigDataT]) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { + if c.cfg == nil || c.cfg.GetContactList == nil { return nil, nil } - return c.config().GetContactList(ctx, c.getSession()) + return c.cfg.GetContactList(ctx, c.getSession()) } // SearchUsers implements bridgev2.UserSearchingNetworkAPI. -func (c *sdkClient) SearchUsers(ctx context.Context, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { - if c.config().SearchUsers == nil { +func (c *sdkClient[SessionT, ConfigDataT]) SearchUsers(ctx context.Context, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { + if c.cfg == nil || c.cfg.SearchUsers == nil { return nil, nil } - return c.config().SearchUsers(ctx, c.getSession(), query) + return c.cfg.SearchUsers(ctx, c.getSession(), query) } diff --git a/sdk/client_resolution_test.go b/sdk/client_resolution_test.go index cca50dab..ba9d44e6 100644 --- a/sdk/client_resolution_test.go +++ b/sdk/client_resolution_test.go @@ -14,8 +14,8 @@ func TestSDKClientResolveIdentifierPreservesFullResponse(t *testing.T) { chat := &bridgev2.CreateChatResponse{ PortalKey: networkid.PortalKey{ID: "portal-1", Receiver: "login-1"}, } - cfg := &Config{ - ResolveIdentifier: func(_ context.Context, _ any, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + cfg := &Config[*bridgev2.UserLogin, *struct{}]{ + ResolveIdentifier: func(_ context.Context, _ *bridgev2.UserLogin, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { if id != "agent:test" { t.Fatalf("unexpected identifier %q", id) } @@ -50,11 +50,11 @@ func TestSDKClientResolveIdentifierPreservesFullResponse(t *testing.T) { func TestSDKClientContactListingAndSearch(t *testing.T) { contact := &bridgev2.ResolveIdentifierResponse{UserID: "agent-user"} - cfg := &Config{ - GetContactList: func(_ context.Context, _ any) ([]*bridgev2.ResolveIdentifierResponse, error) { + cfg := &Config[*bridgev2.UserLogin, *struct{}]{ + GetContactList: func(_ context.Context, _ *bridgev2.UserLogin) ([]*bridgev2.ResolveIdentifierResponse, error) { return []*bridgev2.ResolveIdentifierResponse{contact}, nil }, - SearchUsers: func(_ context.Context, _ any, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { + SearchUsers: func(_ context.Context, _ *bridgev2.UserLogin, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { if query != "agent" { t.Fatalf("unexpected query %q", query) } diff --git a/sdk/commands.go b/sdk/commands.go index 9e021239..aadb7b74 100644 --- a/sdk/commands.go +++ b/sdk/commands.go @@ -15,7 +15,7 @@ import ( var sdkHelpSection = commands.HelpSection{Name: "SDK", Order: 50} // registerCommands registers Config.Commands with the bridgev2 command processor. -func registerCommands(br *bridgev2.Bridge, cfg *Config) { +func registerCommands[SessionT SessionValue, ConfigDataT ConfigValue](br *bridgev2.Bridge, cfg *Config[SessionT, ConfigDataT]) { if len(cfg.Commands) == 0 || br == nil { return } diff --git a/sdk/connector.go b/sdk/connector.go index 2e59cf42..65fbf499 100644 --- a/sdk/connector.go +++ b/sdk/connector.go @@ -15,7 +15,7 @@ import ( ) // NewConnectorBase builds an SDK-backed connector base that can be embedded by custom bridges. -func NewConnectorBase(cfg *Config) *agentremote.ConnectorBase { +func NewConnectorBase[SessionT SessionValue, ConfigDataT ConfigValue](cfg *Config[SessionT, ConfigDataT]) *agentremote.ConnectorBase { mu, clientsRef := cfg.ClientCacheMu, cfg.ClientCache if mu == nil { mu = &sync.Mutex{} @@ -44,7 +44,7 @@ func NewConnectorBase(cfg *Config) *agentremote.ConnectorBase { cfg.UpdateClient(client, login) return } - if typed, ok := client.(*sdkClient); ok { + if typed, ok := client.(*sdkClient[SessionT, ConfigDataT]); ok { typed.SetUserLogin(login) } }, diff --git a/sdk/connector_helpers.go b/sdk/connector_helpers.go index cf4f0798..3583abf4 100644 --- a/sdk/connector_helpers.go +++ b/sdk/connector_helpers.go @@ -16,13 +16,18 @@ import ( ) // BuildStandardMetaTypes returns the common bridge metadata registrations. -func BuildStandardMetaTypes( - newPortal func() any, - newMessage func() any, - newLogin func() any, - newGhost func() any, +func BuildStandardMetaTypes[PortalT, MessageT, LoginT, GhostT any]( + newPortal func() PortalT, + newMessage func() MessageT, + newLogin func() LoginT, + newGhost func() GhostT, ) database.MetaTypes { - return agentremote.BuildMetaTypes(newPortal, newMessage, newLogin, newGhost) + return agentremote.BuildMetaTypes( + func() any { return newPortal() }, + func() any { return newMessage() }, + func() any { return newLogin() }, + func() any { return newGhost() }, + ) } // ApplyDefaultCommandPrefix sets the command prefix when it is empty. @@ -88,7 +93,7 @@ func TypedClientUpdater[T interface { } } -type StandardConnectorConfigParams struct { +type StandardConnectorConfigParams[SessionT SessionValue, ConfigDataT ConfigValue, PortalT, MessageT, LoginT, GhostT any] struct { Name string Description string ProtocolID string @@ -96,7 +101,7 @@ type StandardConnectorConfigParams struct { ClientCacheMu *sync.Mutex ClientCache *map[networkid.UserLoginID]bridgev2.NetworkAPI AgentCatalog AgentCatalog - GetCapabilities func(session any, conv *Conversation) *RoomFeatures + GetCapabilities func(session SessionT, conv *Conversation) *RoomFeatures InitConnector func(br *bridgev2.Bridge) StartConnector func(ctx context.Context, br *bridgev2.Bridge) error StopConnector func(ctx context.Context, br *bridgev2.Bridge) @@ -108,12 +113,12 @@ type StandardConnectorConfigParams struct { DefaultPort uint16 DefaultCommandPrefix func() string ExampleConfig string - ConfigData any + ConfigData ConfigDataT ConfigUpgrader configupgrade.Upgrader - NewPortal func() any - NewMessage func() any - NewLogin func() any - NewGhost func() any + NewPortal func() PortalT + NewMessage func() MessageT + NewLogin func() LoginT + NewGhost func() GhostT NetworkCapabilities func() *bridgev2.NetworkGeneralCapabilities FillBridgeInfo func(portal *bridgev2.Portal, content *event.BridgeEventContent) AcceptLogin func(login *bridgev2.UserLogin) (bool, string) @@ -129,8 +134,8 @@ type StandardConnectorConfigParams struct { // NewStandardConnectorConfig builds the common bridgesdk.Config skeleton used by // the dedicated bridge connectors. -func NewStandardConnectorConfig(p StandardConnectorConfigParams) *Config { - return &Config{ +func NewStandardConnectorConfig[SessionT SessionValue, ConfigDataT ConfigValue, PortalT, MessageT, LoginT, GhostT any](p StandardConnectorConfigParams[SessionT, ConfigDataT, PortalT, MessageT, LoginT, GhostT]) *Config[SessionT, ConfigDataT] { + return &Config[SessionT, ConfigDataT]{ Name: p.Name, Description: p.Description, ProtocolID: p.ProtocolID, diff --git a/sdk/connector_hooks_test.go b/sdk/connector_hooks_test.go index cf6ed7ad..6670ad03 100644 --- a/sdk/connector_hooks_test.go +++ b/sdk/connector_hooks_test.go @@ -59,7 +59,7 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { afterLoadCalled := 0 wantBridge := &bridgev2.Bridge{} - cfg := &Config{ + cfg := &Config[*struct{}, *struct{}]{ Name: "hooked", ClientCacheMu: &mu, ClientCache: &clients, @@ -145,7 +145,7 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { func TestNewConnectorBaseUsesCustomLoadLoginAndLoginFlows(t *testing.T) { loadCalled := 0 - cfg := &Config{ + cfg := &Config[*struct{}, *struct{}]{ Name: "custom-load", LoadLogin: func(_ context.Context, login *bridgev2.UserLogin) error { loadCalled++ @@ -188,7 +188,7 @@ func TestNewConnectorBaseUsesCustomLoadLoginAndLoginFlows(t *testing.T) { } func TestApprovalControllerUsesCustomHandler(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{}, nil) + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), &Agent{ID: "agent"}, nil) called := false diff --git a/sdk/conversation.go b/sdk/conversation.go index ba8ec271..b031c650 100644 --- a/sdk/conversation.go +++ b/sdk/conversation.go @@ -60,13 +60,6 @@ func (c *Conversation) getIntent(ctx context.Context) (bridgev2.MatrixAPI, error return intent, nil } -func (c *Conversation) configOrNil() *Config { - if c.runtime == nil { - return nil - } - return c.runtime.config() -} - func (c *Conversation) stateStore() *conversationStateStore { if c == nil || c.runtime == nil { return nil @@ -97,15 +90,14 @@ func (c *Conversation) resolveDefaultAgent(ctx context.Context) (*Agent, error) return agent, nil } } - cfg := c.configOrNil() - if cfg == nil { + if c.runtime == nil { return nil, nil } - if cfg.Agent != nil { - return cfg.Agent, nil + if agent := c.runtime.agent(); agent != nil { + return agent, nil } - if cfg.AgentCatalog != nil { - return cfg.AgentCatalog.DefaultAgent(ctx, c.login) + if catalog := c.runtime.agentCatalog(); catalog != nil { + return catalog.DefaultAgent(ctx, c.login) } return nil, nil } @@ -114,15 +106,14 @@ func (c *Conversation) resolveAgentByIdentifier(ctx context.Context, identifier if c == nil || strings.TrimSpace(identifier) == "" { return nil, nil } - cfg := c.configOrNil() - if cfg == nil { + if c.runtime == nil { return nil, nil } - if cfg.Agent != nil && cfg.Agent.ID == identifier { - return cfg.Agent, nil + if agent := c.runtime.agent(); agent != nil && agent.ID == identifier { + return agent, nil } - if cfg.AgentCatalog != nil { - return cfg.AgentCatalog.ResolveAgent(ctx, c.login, identifier) + if catalog := c.runtime.agentCatalog(); catalog != nil { + return catalog.ResolveAgent(ctx, c.login, identifier) } return nil, nil } @@ -131,9 +122,8 @@ func (c *Conversation) currentRoomFeatures(ctx context.Context) *RoomFeatures { if c == nil { return nil } - cfg := c.configOrNil() - if cfg != nil && cfg.GetCapabilities != nil { - if rf := cfg.GetCapabilities(c.runtime.sessionValue(), c); rf != nil { + if c.runtime != nil { + if rf := c.runtime.roomFeatures(c); rf != nil { return rf } } @@ -152,9 +142,6 @@ func (c *Conversation) currentRoomFeatures(ctx context.Context) *RoomFeatures { } } if len(agents) == 0 { - if cfg != nil && cfg.RoomFeatures != nil { - return cfg.RoomFeatures - } return defaultSDKFeatureConfig() } return computeRoomFeaturesForAgents(agents) @@ -253,14 +240,6 @@ func (c *Conversation) StartTurn(ctx context.Context, agent *Agent, source *Sour return newTurn(ctx, c, agent, source) } -// Session returns the session state from the client, if available. -func (c *Conversation) Session() any { - if c.runtime == nil { - return nil - } - return c.runtime.sessionValue() -} - // Context returns the conversation's context. func (c *Conversation) Context() context.Context { return c.ctx diff --git a/sdk/conversation_test.go b/sdk/conversation_test.go index d6800614..1b7697ce 100644 --- a/sdk/conversation_test.go +++ b/sdk/conversation_test.go @@ -25,7 +25,7 @@ func (c testAgentCatalog) ResolveAgent(_ context.Context, _ *bridgev2.UserLogin, return c.byIdentifier[identifier], nil } -func newTestConversation(cfg *Config, state sdkConversationState) *Conversation { +func newTestConversation(cfg *Config[struct{}, *struct{}], state sdkConversationState) *Conversation { return newConversation( context.Background(), &bridgev2.Portal{ @@ -36,12 +36,12 @@ func newTestConversation(cfg *Config, state sdkConversationState) *Conversation }, nil, bridgev2.EventSender{}, - &staticRuntime{cfg: cfg}, + &staticRuntime[struct{}, *struct{}]{cfg: cfg}, ) } func TestConversationCurrentRoomFeaturesUsesConfiguredDefaultAgent(t *testing.T) { - conv := newTestConversation(&Config{ + conv := newTestConversation(&Config[struct{}, *struct{}]{ Agent: &Agent{ ID: "default", Capabilities: AgentCapabilities{ @@ -61,7 +61,7 @@ func TestConversationCurrentRoomFeaturesUsesConfiguredDefaultAgent(t *testing.T) } func TestConversationCurrentRoomFeaturesFallsBackAfterUnresolvedAgents(t *testing.T) { - conv := newTestConversation(&Config{ + conv := newTestConversation(&Config[struct{}, *struct{}]{ Agent: &Agent{ ID: "default", Capabilities: AgentCapabilities{ @@ -83,7 +83,7 @@ func TestConversationCurrentRoomFeaturesFallsBackAfterUnresolvedAgents(t *testin } func TestConversationCurrentRoomFeaturesIgnoresUnresolvedAgentsWhenOneResolves(t *testing.T) { - conv := newTestConversation(&Config{ + conv := newTestConversation(&Config[struct{}, *struct{}]{ AgentCatalog: testAgentCatalog{ byIdentifier: map[string]*Agent{ "found": { diff --git a/sdk/login_handle.go b/sdk/login_handle.go index a91710a8..94feb9b8 100644 --- a/sdk/login_handle.go +++ b/sdk/login_handle.go @@ -70,10 +70,10 @@ func (l *LoginHandle) EnsureConversation(ctx context.Context, spec ConversationS AIRoomKind: conv.aiRoomKind(), ForceCapabilities: true, RefreshExtra: func(ctx context.Context, portal *bridgev2.Portal) { - if l.runtime == nil || l.runtime.config() == nil || len(l.runtime.config().Commands) == 0 { + if l.runtime == nil || len(l.runtime.commands()) == 0 { return } - BroadcastCommandDescriptions(ctx, portal, l.login.Bridge.Bot, l.runtime.config().Commands) + BroadcastCommandDescriptions(ctx, portal, l.login.Bridge.Bot, l.runtime.commands()) }, }) if err != nil { diff --git a/sdk/part_apply_test.go b/sdk/part_apply_test.go index 3ccf9a1c..25c478dd 100644 --- a/sdk/part_apply_test.go +++ b/sdk/part_apply_test.go @@ -8,7 +8,7 @@ import ( ) func newPartApplyTestTurn() *Turn { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{}, nil) + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{}, nil) return conv.StartTurn(context.Background(), &Agent{ID: "agent"}, nil) } diff --git a/sdk/runtime.go b/sdk/runtime.go index 1a2c9447..21bd6f3a 100644 --- a/sdk/runtime.go +++ b/sdk/runtime.go @@ -9,36 +9,75 @@ import ( ) type conversationRuntime interface { - config() *Config - sessionValue() any + agent() *Agent + agentCatalog() AgentCatalog + roomFeatures(conv *Conversation) *RoomFeatures + commands() []Command + turnConfig() *TurnConfig conversationStore() *conversationStateStore approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] providerIdentity() ProviderIdentity } -type staticRuntime struct { - cfg *Config - session any +type staticRuntime[SessionT SessionValue, ConfigDataT ConfigValue] struct { + cfg *Config[SessionT, ConfigDataT] + session SessionT login *bridgev2.UserLogin store *conversationStateStore approval *agentremote.ApprovalFlow[*pendingSDKApprovalData] } -func (r *staticRuntime) config() *Config { return r.cfg } +func (r *staticRuntime[SessionT, ConfigDataT]) agent() *Agent { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.Agent +} + +func (r *staticRuntime[SessionT, ConfigDataT]) agentCatalog() AgentCatalog { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.AgentCatalog +} + +func (r *staticRuntime[SessionT, ConfigDataT]) roomFeatures(conv *Conversation) *RoomFeatures { + if r == nil || r.cfg == nil { + return nil + } + if r.cfg.GetCapabilities != nil { + if rf := r.cfg.GetCapabilities(r.session, conv); rf != nil { + return rf + } + } + return r.cfg.RoomFeatures +} -func (r *staticRuntime) sessionValue() any { return r.session } +func (r *staticRuntime[SessionT, ConfigDataT]) commands() []Command { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.Commands +} + +func (r *staticRuntime[SessionT, ConfigDataT]) turnConfig() *TurnConfig { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.TurnManagement +} -func (r *staticRuntime) conversationStore() *conversationStateStore { return r.store } +func (r *staticRuntime[SessionT, ConfigDataT]) conversationStore() *conversationStateStore { return r.store } -func (r *staticRuntime) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { +func (r *staticRuntime[SessionT, ConfigDataT]) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { return r.approval } -func (r *staticRuntime) providerIdentity() ProviderIdentity { +func (r *staticRuntime[SessionT, ConfigDataT]) providerIdentity() ProviderIdentity { return resolveProviderIdentity(r.cfg) } -func resolveProviderIdentity(cfg *Config) ProviderIdentity { +func resolveProviderIdentity[SessionT SessionValue, ConfigDataT ConfigValue](cfg *Config[SessionT, ConfigDataT]) ProviderIdentity { if cfg == nil { return normalizedProviderIdentity(ProviderIdentity{}) } @@ -65,8 +104,8 @@ type NewConversationOptions struct { // NewConversation creates an SDK conversation wrapper for provider bridges that // want to drive SDK turns without using the default sdkClient implementation. -func NewConversation(ctx context.Context, login *bridgev2.UserLogin, portal *bridgev2.Portal, sender bridgev2.EventSender, cfg *Config, session any, opts ...NewConversationOptions) *Conversation { - rt := &staticRuntime{ +func NewConversation[SessionT SessionValue, ConfigDataT ConfigValue](ctx context.Context, login *bridgev2.UserLogin, portal *bridgev2.Portal, sender bridgev2.EventSender, cfg *Config[SessionT, ConfigDataT], session SessionT, opts ...NewConversationOptions) *Conversation { + rt := &staticRuntime[SessionT, ConfigDataT]{ cfg: cfg, session: session, login: login, diff --git a/sdk/turn.go b/sdk/turn.go index 4a21c9b9..015d45f4 100644 --- a/sdk/turn.go +++ b/sdk/turn.go @@ -933,10 +933,10 @@ func (t *Turn) ensureDefaultFinalEditPayload(finishReason, fallbackBody string) func (t *Turn) resolvedIdleTimeout() time.Duration { const defaultIdleTimeout = time.Minute - if t == nil || t.conv == nil || t.conv.runtime == nil || t.conv.runtime.config() == nil || t.conv.runtime.config().TurnManagement == nil { + if t == nil || t.conv == nil || t.conv.runtime == nil || t.conv.runtime.turnConfig() == nil { return defaultIdleTimeout } - timeoutMs := t.conv.runtime.config().TurnManagement.IdleTimeoutMs + timeoutMs := t.conv.runtime.turnConfig().IdleTimeoutMs switch { case timeoutMs < 0: return 0 diff --git a/sdk/turn_test.go b/sdk/turn_test.go index 006e7e2c..be24a5da 100644 --- a/sdk/turn_test.go +++ b/sdk/turn_test.go @@ -192,7 +192,7 @@ func TestTurnRequestApprovalWaitsForResolvedDecision(t *testing.T) { UserMXID: "@owner:test", }, } - runtime := &staticRuntime{ + runtime := &staticRuntime[*struct{}, *struct{}]{ login: login, approval: agentremote.NewApprovalFlow(agentremote.ApprovalFlowConfig[*pendingSDKApprovalData]{ Login: func() *bridgev2.UserLogin { return nil }, @@ -248,7 +248,7 @@ func TestTurnRequestApprovalUsesProvidedApprovalID(t *testing.T) { UserMXID: "@owner:test", }, } - runtime := &staticRuntime{ + runtime := &staticRuntime[*struct{}, *struct{}]{ login: login, approval: agentremote.NewApprovalFlow(agentremote.ApprovalFlowConfig[*pendingSDKApprovalData]{ Login: func() *bridgev2.UserLogin { return nil }, @@ -276,7 +276,7 @@ func TestTurnRequestApprovalUsesProvidedApprovalID(t *testing.T) { } func TestTurnStreamSetTransportReceivesEvents(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{}, nil) + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), &Agent{ID: "agent"}, nil) var gotTurnID string @@ -685,7 +685,7 @@ func TestTurnBuildFinalEditUsesErrorTextFallback(t *testing.T) { } func TestTurnIdleTimeoutAbortsStuckTurn(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{ + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{ TurnManagement: &TurnConfig{IdleTimeoutMs: 20}, }, nil) turn := conv.StartTurn(context.Background(), nil, nil) @@ -704,7 +704,7 @@ func TestTurnIdleTimeoutAbortsStuckTurn(t *testing.T) { } func TestTurnIdleTimeoutResetsOnActivity(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{ + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{ TurnManagement: &TurnConfig{IdleTimeoutMs: 40}, }, nil) turn := conv.StartTurn(context.Background(), nil, nil) diff --git a/sdk/types.go b/sdk/types.go index ecccca1a..4065bb32 100644 --- a/sdk/types.go +++ b/sdk/types.go @@ -215,8 +215,12 @@ type ProviderIdentity struct { StatusNetwork string } +type SessionValue interface{} + +type ConfigValue interface{} + // Config configures the SDK bridge. -type Config struct { +type Config[SessionT SessionValue, ConfigDataT ConfigValue] struct { // Required Name string Description string @@ -229,29 +233,29 @@ type Config struct { // Message handling (required) // session is the value returned by OnConnect; conv is the conversation; // msg is the incoming message; turn is the pre-created Turn for streaming responses. - OnMessage func(session any, conv *Conversation, msg *Message, turn *Turn) error + OnMessage func(session SessionT, conv *Conversation, msg *Message, turn *Turn) error // Event hooks (optional) - OnConnect func(ctx context.Context, login *LoginInfo) (any, error) // returns session state - OnDisconnect func(session any) - OnReaction func(session any, conv *Conversation, reaction *Reaction) error - OnTyping func(session any, conv *Conversation, typing bool) - OnEdit func(session any, conv *Conversation, edit *MessageEdit) error - OnDelete func(session any, conv *Conversation, msgID string) error - OnRoomName func(session any, conv *Conversation, name string) (bool, error) - OnRoomTopic func(session any, conv *Conversation, topic string) (bool, error) + OnConnect func(ctx context.Context, login *LoginInfo) (SessionT, error) // returns session state + OnDisconnect func(session SessionT) + OnReaction func(session SessionT, conv *Conversation, reaction *Reaction) error + OnTyping func(session SessionT, conv *Conversation, typing bool) + OnEdit func(session SessionT, conv *Conversation, edit *MessageEdit) error + OnDelete func(session SessionT, conv *Conversation, msgID string) error + OnRoomName func(session SessionT, conv *Conversation, name string) (bool, error) + OnRoomTopic func(session SessionT, conv *Conversation, topic string) (bool, error) // Turn management (optional) TurnManagement *TurnConfig // Capabilities (optional, dynamic per-conversation) - GetCapabilities func(session any, conv *Conversation) *RoomFeatures + GetCapabilities func(session SessionT, conv *Conversation) *RoomFeatures // Search & chat ops (optional) - SearchUsers func(ctx context.Context, session any, query string) ([]*bridgev2.ResolveIdentifierResponse, error) - GetContactList func(ctx context.Context, session any) ([]*bridgev2.ResolveIdentifierResponse, error) - ResolveIdentifier func(ctx context.Context, session any, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) - CreateChat func(ctx context.Context, session any, params *CreateChatParams) (*bridgev2.CreateChatResponse, error) + SearchUsers func(ctx context.Context, session SessionT, query string) ([]*bridgev2.ResolveIdentifierResponse, error) + GetContactList func(ctx context.Context, session SessionT) ([]*bridgev2.ResolveIdentifierResponse, error) + ResolveIdentifier func(ctx context.Context, session SessionT, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) + CreateChat func(ctx context.Context, session SessionT, params *CreateChatParams) (*bridgev2.CreateChatResponse, error) DeleteChat func(conv *Conversation) error GetChatInfo func(conv *Conversation) (*bridgev2.ChatInfo, error) GetUserInfo func(ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) @@ -296,6 +300,6 @@ type Config struct { ConfigPath string // default: auto-discover DBMeta func() database.MetaTypes // nil = default ExampleConfig string // YAML - ConfigData any // config struct pointer + ConfigData ConfigDataT // config struct pointer ConfigUpgrader configupgrade.Upgrader } From da8935a258fe39b53f0975cafc11bdc74c5cbf5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 23:46:09 +0200 Subject: [PATCH 13/23] Refactor integration types and introduce generics Migrate integration and connector code to stronger typed APIs and generics. Key changes: - Use integrationruntime.ModuleHooks and concrete integrationruntime types instead of generic any for module registration, scopes and callbacks. - Update bridgesdk.Config usages to the generic form (e.g. bridgesdk.Config[*AIClient, *Config]) and adjust NewStandardConnectorConfig parameter factories to return concrete pointer types. - Replace many host/handler signatures to accept concrete types (e.g. *bridgev2.Portal, *PortalMetadata, *dbutil.Database) and remove redundant runtime casts. - Add PortalMetadata helper accessors (AgentID, CompactionCounter, InternalRoom, ModuleMetaValue, SetModuleMetaValue). - Improve compaction retry logic by caching token estimates from preflight flush, return an int from the preflight hook, and emit lifecycle events without embedding client references. - Simplify header handling in media understanding helpers to avoid unnecessary copying. - Propagate truncated flag into file message XML and simplify XML building. - Rename and tighten dummybridge session helpers (sessionFromAny -> requireSession) and update related APIs to use typed sessions. - Adjust tests to use the new generic bridgesdk.NewConversation signatures and other updated APIs. These changes tighten type safety, reduce runtime type assertions, and prepare the codebase for clearer integration APIs. --- bridges/ai/turn_validation.go | 208 --------------------------- bridges/ai/turn_validation_test.go | 113 --------------- pkg/integrations/cron/integration.go | 11 +- sdk/runtime.go | 4 +- 4 files changed, 5 insertions(+), 331 deletions(-) delete mode 100644 bridges/ai/turn_validation.go delete mode 100644 bridges/ai/turn_validation_test.go diff --git a/bridges/ai/turn_validation.go b/bridges/ai/turn_validation.go deleted file mode 100644 index ed3cf51e..00000000 --- a/bridges/ai/turn_validation.go +++ /dev/null @@ -1,208 +0,0 @@ -package ai - -import ( - "strings" - - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" -) - -// IsGoogleModel returns true if the model ID looks like a Google/Gemini model. -func IsGoogleModel(modelID string) bool { - lower := strings.ToLower(modelID) - return strings.HasPrefix(lower, "google/") || - strings.HasPrefix(lower, "gemini") || - strings.Contains(lower, "/gemini") -} - -// SanitizeGoogleTurnOrdering fixes prompt ordering for Google models: -// - Merges consecutive user messages -// - Merges consecutive assistant messages -// - Prepends a synthetic user turn if history starts with an assistant message -func SanitizeGoogleTurnOrdering(prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { - if len(prompt) == 0 { - return prompt - } - - // Separate system messages (keep at front) from conversation messages - var system []openai.ChatCompletionMessageParamUnion - var conversation []openai.ChatCompletionMessageParamUnion - for _, msg := range prompt { - if chatMessageRole(msg) == "system" { - system = append(system, msg) - } else { - conversation = append(conversation, msg) - } - } - - if len(conversation) == 0 { - return prompt - } - - // Merge consecutive same-role messages - merged := mergeConsecutiveSameRole(conversation) - - // If the first non-system message is assistant, prepend a synthetic user turn - if len(merged) > 0 && chatMessageRole(merged[0]) == "assistant" { - merged = append([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("(continued from previous session)"), - }, merged...) - } - - return append(system, merged...) -} - -// mergeConsecutiveSameRole combines adjacent messages with the same role. -// For user messages, it preserves multimodal content parts (images, etc.) -// by collecting all parts from the run into a single OfArrayOfContentParts message. -// For assistant messages, it concatenates text bodies with double newlines. -func mergeConsecutiveSameRole(msgs []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { - if len(msgs) <= 1 { - return msgs - } - - var result []openai.ChatCompletionMessageParamUnion - i := 0 - for i < len(msgs) { - role := chatMessageRole(msgs[i]) - j := i + 1 - for j < len(msgs) && chatMessageRole(msgs[j]) == role { - j++ - } - - // Single message, no merging needed — keep as-is. - if j == i+1 { - result = append(result, msgs[i]) - i = j - continue - } - - // Multiple consecutive messages with the same role — merge them. - run := msgs[i:j] - switch role { - case "user": - result = append(result, mergeUserMessages(run)) - case "assistant": - // Assistant messages are always text-only (images go in synthetic user messages). - var body string - for _, m := range run { - nextBody := chatMessageBody(m) - if nextBody != "" { - if body != "" { - body += "\n\n" - } - body += nextBody - } - } - result = append(result, openai.AssistantMessage(body)) - default: - // For other roles, concatenate text. - var body string - for _, m := range run { - nextBody := chatMessageBody(m) - if nextBody != "" { - if body != "" { - body += "\n\n" - } - body += nextBody - } - } - result = append(result, openai.UserMessage(body)) - } - i = j - } - return result -} - -// mergeUserMessages merges a run of consecutive user messages into one, -// preserving multimodal content parts (OfArrayOfContentParts) if any message has them. -func mergeUserMessages(run []openai.ChatCompletionMessageParamUnion) openai.ChatCompletionMessageParamUnion { - // Check if any message in the run has multimodal parts. - hasMultimodal := false - for _, m := range run { - if m.OfUser != nil && len(m.OfUser.Content.OfArrayOfContentParts) > 0 { - hasMultimodal = true - break - } - } - - // If no multimodal content, do the simple text merge. - if !hasMultimodal { - var body string - for _, m := range run { - nextBody := chatMessageBody(m) - if nextBody != "" { - if body != "" { - body += "\n\n" - } - body += nextBody - } - } - return openai.UserMessage(body) - } - - // Collect all content parts, converting plain-text messages to text parts. - var allParts []openai.ChatCompletionContentPartUnionParam - for idx, m := range run { - if m.OfUser == nil { - continue - } - if len(m.OfUser.Content.OfArrayOfContentParts) > 0 { - // Add a separator text part between merged messages (except the first). - if idx > 0 && len(allParts) > 0 { - allParts = append(allParts, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: "\n\n"}, - }) - } - allParts = append(allParts, m.OfUser.Content.OfArrayOfContentParts...) - } else if m.OfUser.Content.OfString.Value != "" { - if len(allParts) > 0 { - allParts = append(allParts, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: "\n\n"}, - }) - } - allParts = append(allParts, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: m.OfUser.Content.OfString.Value}, - }) - } - } - - return openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfArrayOfContentParts: allParts, - }, - }, - } -} - -// chatMessageRole extracts the role string from a ChatCompletionMessageParamUnion. -// GetRole() returns empty strings at construction time (constant types marshal lazily), -// so we check which Of* field is populated instead. -func chatMessageRole(msg openai.ChatCompletionMessageParamUnion) string { - if !param.IsOmitted(msg.OfSystem) { - return "system" - } - if !param.IsOmitted(msg.OfUser) { - return "user" - } - if !param.IsOmitted(msg.OfAssistant) { - return "assistant" - } - if !param.IsOmitted(msg.OfTool) { - return "tool" - } - if !param.IsOmitted(msg.OfDeveloper) { - return "developer" - } - return "user" -} - -// chatMessageBody extracts the text body from a ChatCompletionMessageParamUnion. -func chatMessageBody(msg openai.ChatCompletionMessageParamUnion) string { - c := msg.GetContent() - if s, ok := c.AsAny().(*string); ok && s != nil { - return *s - } - return "" -} diff --git a/bridges/ai/turn_validation_test.go b/bridges/ai/turn_validation_test.go deleted file mode 100644 index 80ca9af6..00000000 --- a/bridges/ai/turn_validation_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package ai - -import ( - "testing" - - "github.com/openai/openai-go/v3" -) - -func TestIsGoogleModel(t *testing.T) { - tests := []struct { - model string - want bool - }{ - {"google/gemini-2.5-flash", true}, - {"google/gemini-3.1-pro-preview", true}, - {"gemini-pro", true}, - {"openrouter/google/gemini-flash", true}, - {"anthropic/claude-sonnet-4.5", false}, - {"openai/gpt-5", false}, - {"", false}, - } - for _, tt := range tests { - if got := IsGoogleModel(tt.model); got != tt.want { - t.Errorf("IsGoogleModel(%q) = %v, want %v", tt.model, got, tt.want) - } - } -} - -func TestSanitizeGoogleTurnOrdering_MergesConsecutiveUser(t *testing.T) { - prompt := []openai.ChatCompletionMessageParamUnion{ - openai.SystemMessage("system"), - openai.UserMessage("hello"), - openai.UserMessage("world"), - openai.AssistantMessage("hi"), - } - result := SanitizeGoogleTurnOrdering(prompt) - if hasConsecutiveUserOrAssistantRoles(result) { - t.Fatal("expected sanitized prompt to be valid") - } - // system + merged-user + assistant = 3 - if len(result) != 3 { - t.Fatalf("expected 3 messages, got %d", len(result)) - } -} - -func TestSanitizeGoogleTurnOrdering_PrependsSyntheticUser(t *testing.T) { - prompt := []openai.ChatCompletionMessageParamUnion{ - openai.SystemMessage("system"), - openai.AssistantMessage("I was speaking"), - openai.UserMessage("ok"), - } - result := SanitizeGoogleTurnOrdering(prompt) - if hasConsecutiveUserOrAssistantRoles(result) { - t.Fatal("expected sanitized prompt to be valid") - } - // system + synthetic-user + assistant + user = 4 - if len(result) != 4 { - t.Fatalf("expected 4 messages, got %d", len(result)) - } - // First non-system should be user - if chatMessageRole(result[1]) != "user" { - t.Fatalf("expected synthetic user message, got %s", chatMessageRole(result[1])) - } -} - -func TestSanitizeGoogleTurnOrdering_Empty(t *testing.T) { - result := SanitizeGoogleTurnOrdering(nil) - if result != nil { - t.Fatal("expected nil for nil input") - } -} - -func TestSanitizeGoogleTurnOrdering_AlreadyValid(t *testing.T) { - prompt := []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("hello"), - openai.AssistantMessage("hi"), - } - result := SanitizeGoogleTurnOrdering(prompt) - if len(result) != 2 { - t.Fatalf("expected 2 messages for already-valid prompt, got %d", len(result)) - } -} - -func TestChatMessageRole(t *testing.T) { - tests := []struct { - msg openai.ChatCompletionMessageParamUnion - want string - }{ - {openai.SystemMessage("sys"), "system"}, - {openai.UserMessage("usr"), "user"}, - {openai.AssistantMessage("ast"), "assistant"}, - } - for _, tt := range tests { - if got := chatMessageRole(tt.msg); got != tt.want { - t.Errorf("chatMessageRole() = %q, want %q", got, tt.want) - } - } -} - -func hasConsecutiveUserOrAssistantRoles(prompt []openai.ChatCompletionMessageParamUnion) bool { - lastRole := "" - for _, msg := range prompt { - role := chatMessageRole(msg) - if role == "system" { - continue - } - if role == lastRole && (role == "user" || role == "assistant") { - return true - } - lastRole = role - } - return false -} diff --git a/pkg/integrations/cron/integration.go b/pkg/integrations/cron/integration.go index f9a028a6..8cf3ae16 100644 --- a/pkg/integrations/cron/integration.go +++ b/pkg/integrations/cron/integration.go @@ -138,7 +138,7 @@ func (i *Integration) executeCronCommand(ctx context.Context, call iruntime.Comm reply("Cron add failed: %s", err.Error()) return nil } - deps := i.buildToolExecDeps(ctx, commandScopeToToolScope(call.Scope)) + deps := i.buildToolExecDeps(ctx, iruntime.ToolScope(call.Scope)) injectToolContext(&input, deps.ResolveCreateContext) if input.Delivery != nil && strings.EqualFold(strings.TrimSpace(string(input.Delivery.Mode)), "announce") && deps.ValidateDeliveryTo != nil { if err := deps.ValidateDeliveryTo(input.Delivery.To); err != nil { @@ -169,7 +169,7 @@ func (i *Integration) executeCronCommand(ctx context.Context, call iruntime.Comm reply("Cron update failed: %s", err.Error()) return nil } - deps := i.buildToolExecDeps(ctx, commandScopeToToolScope(call.Scope)) + deps := i.buildToolExecDeps(ctx, iruntime.ToolScope(call.Scope)) if patch.Delivery != nil && patch.Delivery.To != nil && deps.ValidateDeliveryTo != nil { if err := deps.ValidateDeliveryTo(*patch.Delivery.To); err != nil { reply("Cron update failed: %s", err.Error()) @@ -279,13 +279,6 @@ func (i *Integration) buildToolExecDeps(ctx context.Context, scope iruntime.Tool return deps } -func commandScopeToToolScope(scope iruntime.CommandScope) iruntime.ToolScope { - return iruntime.ToolScope{ - Portal: scope.Portal, - Meta: scope.Meta, - } -} - var ( _ iruntime.ToolIntegration = (*Integration)(nil) _ iruntime.CommandIntegration = (*Integration)(nil) diff --git a/sdk/runtime.go b/sdk/runtime.go index 21bd6f3a..f433244d 100644 --- a/sdk/runtime.go +++ b/sdk/runtime.go @@ -67,7 +67,9 @@ func (r *staticRuntime[SessionT, ConfigDataT]) turnConfig() *TurnConfig { return r.cfg.TurnManagement } -func (r *staticRuntime[SessionT, ConfigDataT]) conversationStore() *conversationStateStore { return r.store } +func (r *staticRuntime[SessionT, ConfigDataT]) conversationStore() *conversationStateStore { + return r.store +} func (r *staticRuntime[SessionT, ConfigDataT]) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { return r.approval From c97b2fda7fe18b83942f83c77d4a80eff4f0c6a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 29 Mar 2026 23:59:17 +0200 Subject: [PATCH 14/23] Use slices.Clone and simplify SystemPrompt trim Replace manual slice-copy patterns (append([]T{}, ...)) with slices.Clone for clarity and to avoid shared backing arrays. Clone opts.prepend, opts.append, and opts.leadingBlocks where used. Simplify trimming of SystemPrompt by assigning TrimSpace result to a variable before checking/adding the system message. Add the "slices" import. --- bridges/ai/prompt_builder.go | 9 +++++---- bridges/ai/prompt_context_local.go | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/bridges/ai/prompt_builder.go b/bridges/ai/prompt_builder.go index 3ce85c48..f0423a44 100644 --- a/bridges/ai/prompt_builder.go +++ b/bridges/ai/prompt_builder.go @@ -3,6 +3,7 @@ package ai import ( "context" "fmt" + "slices" "strings" "maunium.net/go/mautrix/bridgev2" @@ -168,7 +169,7 @@ func (oc *AIClient) buildCurrentTurnText( return PromptContext{}, "", err } - prepend := append([]string{}, opts.prepend...) + prepend := slices.Clone(opts.prepend) if portal != nil && portal.MXID != "" { reactionFeedback := DrainReactionFeedback(portal.MXID) if len(reactionFeedback) > 0 { @@ -181,7 +182,7 @@ func (oc *AIClient) buildCurrentTurnText( prepend = append(prepend, result.UntrustedPrefix) } - appendParts := append([]string{}, opts.append...) + appendParts := slices.Clone(opts.append) if opts.includeLinkScope { if linkContext := oc.buildLinkContext(ctx, userText, opts.rawEventContent); linkContext != "" { appendParts = append(appendParts, linkContext) @@ -200,8 +201,8 @@ func (oc *AIClient) buildPromptContextForTurn( eventID id.EventID, opts currentTurnPromptOptions, ) (PromptContext, error) { - appendFragments := append([]string{}, opts.append...) - leadingBlocks := append([]PromptBlock{}, opts.leadingBlocks...) + appendFragments := slices.Clone(opts.append) + leadingBlocks := slices.Clone(opts.leadingBlocks) if opts.attachment != nil { attachmentBlocks, attachmentAppend, err := oc.normalizeTurnAttachment(ctx, *opts.attachment) diff --git a/bridges/ai/prompt_context_local.go b/bridges/ai/prompt_context_local.go index b243e23b..08efd8da 100644 --- a/bridges/ai/prompt_context_local.go +++ b/bridges/ai/prompt_context_local.go @@ -125,8 +125,8 @@ func promptMessageToResponsesInputs(msg PromptMessage) responses.ResponseInputPa func promptContextToChatCompletionMessages(ctx PromptContext, supportsVideoURL bool) []openai.ChatCompletionMessageParamUnion { var messages []openai.ChatCompletionMessageParamUnion - if strings.TrimSpace(ctx.SystemPrompt) != "" { - messages = append(messages, openai.SystemMessage(strings.TrimSpace(ctx.SystemPrompt))) + if system := strings.TrimSpace(ctx.SystemPrompt); system != "" { + messages = append(messages, openai.SystemMessage(system)) } for _, msg := range ctx.Messages { switch msg.Role { From b722e265a0e3e7247ecd55b94dcdfefd7dda243e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 01:24:26 +0200 Subject: [PATCH 15/23] Snapshot pending events and refine tool selection Add snapshot utilities to avoid races when queuing events: introduce pending_event.go with snapshotPendingEvent and deep-clone helpers for event.Content.Raw, and update handlers (debounce, matrix, regenerate, media/text handlers) to use snapshots and cloned raw maps before dispatching/queuing. Add pending_event_test to verify reply-target preservation after source mutation. Refactor builtin tool selection: introduce builtinToolPreset, model preset list, selectToolDefinitionsByName, and make selectedBuiltinToolsForTurn choose tools based on a preset (model vs agent), plus tests and helpers (testBuiltinToolClient, toolDefinitionNames) to cover various web-tool configurations. Also add/update streaming tests to use the new helpers and validate tool lists. --- bridges/ai/client.go | 9 +- bridges/ai/handlematrix.go | 25 +++--- bridges/ai/pending_event.go | 53 ++++++++++++ bridges/ai/pending_event_test.go | 49 +++++++++++ bridges/ai/streaming_request_tools_test.go | 48 ++++------- bridges/ai/streaming_tool_selection.go | 61 +++++++++++++- bridges/ai/streaming_tool_selection_test.go | 93 ++++++++++++++++----- 7 files changed, 268 insertions(+), 70 deletions(-) create mode 100644 bridges/ai/pending_event.go create mode 100644 bridges/ai/pending_event_test.go diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 1d277789..e53ee8ed 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -2018,8 +2018,9 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { ctx = withInboundContext(ctx, inboundCtx) rawEventContent := map[string]any(nil) if last.Event != nil && last.Event.Content.Raw != nil { - rawEventContent = last.Event.Content.Raw + rawEventContent = clonePendingRawMap(last.Event.Content.Raw) } + pendingEvent := snapshotPendingEvent(last.Event) extraStatusEvents := make([]*event.Event, 0, len(entries)-1) if len(entries) > 1 { @@ -2077,7 +2078,7 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { } pending := pendingMessage{ - Event: last.Event, + Event: pendingEvent, Portal: last.Portal, Meta: last.Meta, InboundContext: &inboundCtx, @@ -2094,14 +2095,14 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { } queueItem := pendingQueueItem{ pending: pending, - messageID: string(last.Event.ID), + messageID: string(pendingEvent.ID), summaryLine: combinedRaw, enqueuedAt: time.Now().UnixMilli(), rawEventContent: rawEventContent, } queueSettings, _, _, _ := oc.resolveQueueSettingsForPortal(statusCtx, last.Portal, last.Meta, "", airuntime.QueueInlineOptions{}) - _, _ = oc.dispatchOrQueue(statusCtx, last.Event, last.Portal, last.Meta, nil, queueItem, queueSettings, promptContext) + _, _ = oc.dispatchOrQueue(statusCtx, pendingEvent, last.Portal, last.Meta, nil, queueItem, queueSettings, promptContext) } diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 57f2fa6b..1a9d6429 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -250,8 +250,9 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri // Get raw event content for link previews var rawEventContent map[string]any if msg.Event != nil && msg.Event.Content.Raw != nil { - rawEventContent = msg.Event.Content.Raw + rawEventContent = clonePendingRawMap(msg.Event.Content.Raw) } + pendingEvent := snapshotPendingEvent(msg.Event) eventID := id.EventID("") if msg.Event != nil { @@ -279,7 +280,7 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri } pending := pendingMessage{ - Event: msg.Event, + Event: pendingEvent, Portal: portal, Meta: runMeta, InboundContext: &inboundCtx, @@ -300,7 +301,7 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri enqueuedAt: time.Now().UnixMilli(), rawEventContent: rawEventContent, } - dbMsg, isPending := oc.dispatchOrQueue(runCtx, msg.Event, portal, runMeta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(runCtx, pendingEvent, portal, runMeta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, @@ -435,8 +436,9 @@ func (oc *AIClient) regenerateFromEdit( queueSettings, _, _, _ := oc.resolveQueueSettingsForPortal(ctx, portal, meta, "", airuntime.QueueInlineOptions{}) isGroup := oc.isGroupChat(ctx, portal) + pendingEvent := snapshotPendingEvent(evt) pending := pendingMessage{ - Event: evt, + Event: pendingEvent, Portal: portal, Meta: meta, Type: pendingTypeEditRegenerate, @@ -453,7 +455,7 @@ func (oc *AIClient) regenerateFromEdit( summaryLine: newBody, enqueuedAt: time.Now().UnixMilli(), } - oc.dispatchOrQueueCore(ctx, evt, portal, meta, nil, queueItem, queueSettings, promptContext) + oc.dispatchOrQueueCore(ctx, pendingEvent, portal, meta, nil, queueItem, queueSettings, promptContext) return nil } @@ -599,6 +601,7 @@ func (oc *AIClient) handleMediaMessage( if msg.Content.File != nil { encryptedFile = msg.Content.File } + pendingEvent := snapshotPendingEvent(msg.Event) dispatchTextOnly := func(rawBody string) (*bridgev2.MatrixMessageResponse, error) { inboundCtx := oc.buildMatrixInboundContext(portal, msg.Event, rawBody, senderName, roomName, isGroup) @@ -623,7 +626,7 @@ func (oc *AIClient) handleMediaMessage( userMessage.SendTxnID = networkid.RawTransactionID(msg.InputTransactionID) } pending := pendingMessage{ - Event: msg.Event, + Event: pendingEvent, Portal: portal, Meta: meta, InboundContext: &inboundCtx, @@ -638,7 +641,7 @@ func (oc *AIClient) handleMediaMessage( summaryLine: rawBody, enqueuedAt: time.Now().UnixMilli(), } - dbMsg, isPending := oc.dispatchOrQueue(promptCtx, msg.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(promptCtx, pendingEvent, portal, meta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, Pending: isPending, @@ -746,7 +749,7 @@ func (oc *AIClient) handleMediaMessage( } pending := pendingMessage{ - Event: msg.Event, + Event: snapshotPendingEvent(msg.Event), Portal: portal, Meta: meta, InboundContext: &captionInboundCtx, @@ -764,7 +767,7 @@ func (oc *AIClient) handleMediaMessage( summaryLine: rawCaption, enqueuedAt: time.Now().UnixMilli(), } - dbMsg, isPending := oc.dispatchOrQueue(promptCtx, msg.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(promptCtx, pending.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, @@ -891,7 +894,7 @@ func (oc *AIClient) handleTextFileMessage( } pending := pendingMessage{ - Event: msg.Event, + Event: snapshotPendingEvent(msg.Event), Portal: portal, Meta: meta, InboundContext: &inboundCtx, @@ -906,7 +909,7 @@ func (oc *AIClient) handleTextFileMessage( summaryLine: strings.TrimSpace(rawCaption), enqueuedAt: time.Now().UnixMilli(), } - dbMsg, isPending := oc.dispatchOrQueue(promptCtx, msg.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(promptCtx, pending.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, diff --git a/bridges/ai/pending_event.go b/bridges/ai/pending_event.go new file mode 100644 index 00000000..6a357a65 --- /dev/null +++ b/bridges/ai/pending_event.go @@ -0,0 +1,53 @@ +package ai + +import "maunium.net/go/mautrix/event" + +// snapshotPendingEvent copies only the event fields that queued/goroutine-based +// reply targeting and status propagation depend on. +func snapshotPendingEvent(evt *event.Event) *event.Event { + if evt == nil { + return nil + } + cloned := &event.Event{ + ID: evt.ID, + Type: evt.Type, + Sender: evt.Sender, + } + if len(evt.Content.Raw) > 0 { + cloned.Content.Raw = clonePendingRawMap(evt.Content.Raw) + } + return cloned +} + +func clonePendingRawMap(src map[string]any) map[string]any { + if len(src) == 0 { + return nil + } + cloned := make(map[string]any, len(src)) + for k, v := range src { + cloned[k] = clonePendingRawValue(v) + } + return cloned +} + +func clonePendingRawSlice(src []any) []any { + if len(src) == 0 { + return nil + } + cloned := make([]any, len(src)) + for i, v := range src { + cloned[i] = clonePendingRawValue(v) + } + return cloned +} + +func clonePendingRawValue(v any) any { + switch typed := v.(type) { + case map[string]any: + return clonePendingRawMap(typed) + case []any: + return clonePendingRawSlice(typed) + default: + return v + } +} diff --git a/bridges/ai/pending_event_test.go b/bridges/ai/pending_event_test.go new file mode 100644 index 00000000..2ed950c2 --- /dev/null +++ b/bridges/ai/pending_event_test.go @@ -0,0 +1,49 @@ +package ai + +import ( + "context" + "testing" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestSnapshotPendingEventPreservesReplyTargetAfterSourceMutation(t *testing.T) { + original := &event.Event{ + ID: id.EventID("$evt"), + Sender: id.UserID("@alice:example.com"), + Content: event.Content{ + Raw: map[string]any{ + "m.relates_to": map[string]any{ + "m.in_reply_to": map[string]any{ + "event_id": "$parent", + }, + }, + }, + }, + } + + snapshot := snapshotPendingEvent(original) + original.ID = id.EventID("$mutated") + original.Sender = id.UserID("@bob:example.com") + original.Content.Raw["m.relates_to"].(map[string]any)["m.in_reply_to"].(map[string]any)["event_id"] = "$other-parent" + + oc := &AIClient{} + meta := modelModeTestMeta("openai/gpt-5.2") + prep, cleanup := oc.prepareStreamingRun(context.Background(), zerolog.Nop(), snapshot, nil, meta) + defer cleanup() + + if prep.State == nil || prep.State.turn == nil { + t.Fatalf("expected streaming turn to be prepared") + } + if got := prep.State.turn.Source().EventID; got != "$evt" { + t.Fatalf("expected snapped source event id %q, got %q", "$evt", got) + } + if got := prep.State.turn.Source().SenderID; got != "@alice:example.com" { + t.Fatalf("expected snapped sender id %q, got %q", "@alice:example.com", got) + } + if got := prep.State.replyTarget.ReplyTo; got != id.EventID("$parent") { + t.Fatalf("expected snapped reply target %q, got %q", "$parent", got) + } +} diff --git a/bridges/ai/streaming_request_tools_test.go b/bridges/ai/streaming_request_tools_test.go index f067123f..2fab20fe 100644 --- a/bridges/ai/streaming_request_tools_test.go +++ b/bridges/ai/streaming_request_tools_test.go @@ -2,48 +2,32 @@ package ai import ( "context" + "slices" "testing" - "time" - - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" ) -func testToolSelectionClient(supportsToolCalling bool) *AIClient { - return &AIClient{ - connector: &OpenAIConnector{ - Config: Config{ - Tools: ToolProvidersConfig{ - Web: &WebToolsConfig{Search: &SearchConfig{ - Exa: ProviderExaConfig{APIKey: "test"}, - }}, - }, - }, - }, - UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ - ModelCache: &ModelCache{ - Models: []ModelInfo{{ - ID: "openai/gpt-5.2", - SupportsToolCalling: supportsToolCalling, - }}, - LastRefresh: time.Now().Unix(), - CacheDuration: 3600, - }, - }}}, - } -} - func TestSelectedStreamingToolDescriptorsSkipsAllToolsWhenModelCannotCallTools(t *testing.T) { - meta := agentModeTestMeta("beeper") - meta.RuntimeModelOverride = "openai/gpt-5.2" + meta := modelModeTestMeta("openai/gpt-5.2") - withTools := testToolSelectionClient(true).selectedStreamingToolDescriptors(context.Background(), meta, false) + withTools := testBuiltinToolClient(true, true, true).selectedStreamingToolDescriptors(context.Background(), meta, false) if len(withTools) == 0 { t.Fatal("expected tool descriptors when tool calling is supported") } - withoutTools := testToolSelectionClient(false).selectedStreamingToolDescriptors(context.Background(), meta, false) + withoutTools := testBuiltinToolClient(false, true, true).selectedStreamingToolDescriptors(context.Background(), meta, false) if len(withoutTools) != 0 { t.Fatalf("expected no tool descriptors when tool calling is unsupported, got %#v", withoutTools) } } + +func TestBuildResponsesAgentLoopParams_ModelRoomUsesModelPreset(t *testing.T) { + meta := modelModeTestMeta("openai/gpt-5.2") + client := testBuiltinToolClient(true, true, true) + + params := client.buildResponsesAgentLoopParams(context.Background(), meta, "system prompt", nil, false) + got := responsesToolNames(params.Tools) + want := []string{toolNameSessionStatus, toolNameWebFetch, ToolNameWebSearch} + if !slices.Equal(got, want) { + t.Fatalf("unexpected response tool list: got %v want %v", got, want) + } +} diff --git a/bridges/ai/streaming_tool_selection.go b/bridges/ai/streaming_tool_selection.go index 1b5e70d0..4b46a378 100644 --- a/bridges/ai/streaming_tool_selection.go +++ b/bridges/ai/streaming_tool_selection.go @@ -2,15 +2,68 @@ package ai import "context" -// selectedBuiltinToolsForTurn returns builtin tools exposed to the model for a turn. -func (oc *AIClient) selectedBuiltinToolsForTurn(ctx context.Context, meta *PortalMetadata) []ToolDefinition { - if meta == nil || !oc.getModelCapabilitiesForMeta(ctx, meta).SupportsToolCalling { +type builtinToolPreset string + +const ( + builtinToolPresetNone builtinToolPreset = "" + builtinToolPresetModel builtinToolPreset = "model" + builtinToolPresetAgent builtinToolPreset = "agent" +) + +var modelChatBuiltinToolNames = []string{ + ToolNameWebSearch, + toolNameWebFetch, + toolNameSessionStatus, +} + +func selectToolDefinitionsByName(available []ToolDefinition, names []string) []ToolDefinition { + if len(available) == 0 || len(names) == 0 { return nil } + availableByName := make(map[string]ToolDefinition, len(available)) + for _, tool := range available { + if tool.Name == "" { + continue + } + availableByName[tool.Name] = tool + } + + selected := make([]ToolDefinition, 0, len(names)) + for _, name := range names { + tool, ok := availableByName[name] + if !ok { + continue + } + selected = append(selected, tool) + } + return selected +} + +func (oc *AIClient) builtinToolPresetForTurn(ctx context.Context, meta *PortalMetadata) builtinToolPreset { + if meta == nil || !oc.getModelCapabilitiesForMeta(ctx, meta).SupportsToolCalling { + return builtinToolPresetNone + } if resolveAgentID(meta) == "" { + return builtinToolPresetModel + } + return builtinToolPresetAgent +} + +// selectedBuiltinToolsForTurn returns builtin tools exposed to the model for a turn. +func (oc *AIClient) selectedBuiltinToolsForTurn(ctx context.Context, meta *PortalMetadata) []ToolDefinition { + preset := oc.builtinToolPresetForTurn(ctx, meta) + if preset == builtinToolPresetNone { return nil } - return oc.enabledBuiltinToolsForModel(ctx, meta) + enabledTools := oc.enabledBuiltinToolsForModel(ctx, meta) + switch preset { + case builtinToolPresetModel: + return selectToolDefinitionsByName(enabledTools, modelChatBuiltinToolNames) + case builtinToolPresetAgent: + return enabledTools + default: + return nil + } } diff --git a/bridges/ai/streaming_tool_selection_test.go b/bridges/ai/streaming_tool_selection_test.go index 9488e415..b9d6322b 100644 --- a/bridges/ai/streaming_tool_selection_test.go +++ b/bridges/ai/streaming_tool_selection_test.go @@ -2,6 +2,7 @@ package ai import ( "context" + "slices" "testing" "time" @@ -9,14 +10,30 @@ import ( "maunium.net/go/mautrix/bridgev2/database" ) -func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) { - client := &AIClient{ +func testBuiltinToolClient(supportsToolCalling, searchConfigured, fetchConfigured bool) *AIClient { + searchCfg := &SearchConfig{ + Exa: ProviderExaConfig{Enabled: boolPtr(false)}, + } + if searchConfigured { + searchCfg.Exa = ProviderExaConfig{ + Enabled: boolPtr(true), + APIKey: "test-key", + } + } + + fetchCfg := &FetchConfig{ + Exa: ProviderExaConfig{Enabled: boolPtr(false)}, + Direct: ProviderDirectConfig{Enabled: boolPtr(fetchConfigured)}, + } + + return &AIClient{ connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Web: &WebToolsConfig{Search: &SearchConfig{ - Exa: ProviderExaConfig{APIKey: "test-key"}, - }}, + Web: &WebToolsConfig{ + Search: searchCfg, + Fetch: fetchCfg, + }, }, }, }, @@ -24,13 +41,29 @@ func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) ModelCache: &ModelCache{ Models: []ModelInfo{{ ID: "openai/gpt-5.2", - SupportsToolCalling: true, + SupportsToolCalling: supportsToolCalling, }}, LastRefresh: time.Now().Unix(), CacheDuration: 3600, }, }}}, } +} + +func toolDefinitionNames(tools []ToolDefinition) []string { + names := make([]string, 0, len(tools)) + for _, tool := range tools { + if tool.Name == "" { + continue + } + names = append(names, tool.Name) + } + slices.Sort(names) + return names +} + +func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) { + client := testBuiltinToolClient(true, true, true) meta := agentModeTestMeta("beeper") meta.RuntimeModelOverride = "openai/gpt-5.2" @@ -41,23 +74,45 @@ func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) } } -func TestSelectedBuiltinToolsForTurn_ModelRoomGetsNoTools(t *testing.T) { - client := &AIClient{ - connector: &OpenAIConnector{ - Config: Config{ - Tools: ToolProvidersConfig{ - Web: &WebToolsConfig{Search: &SearchConfig{ - Exa: ProviderExaConfig{APIKey: "test-key"}, - }}, - }, - }, - }, +func TestSelectedBuiltinToolsForTurn_ModelRoomExposesModelPreset(t *testing.T) { + client := testBuiltinToolClient(true, true, true) + meta := modelModeTestMeta("openai/gpt-5.2") + + got := toolDefinitionNames(client.selectedBuiltinToolsForTurn(context.Background(), meta)) + want := []string{toolNameSessionStatus, toolNameWebFetch, ToolNameWebSearch} + if !slices.Equal(got, want) { + t.Fatalf("unexpected model room builtin tools: got %v want %v", got, want) + } +} + +func TestSelectedBuiltinToolsForTurn_ModelRoomOmitsUnavailableWebTools(t *testing.T) { + client := testBuiltinToolClient(true, false, false) + meta := modelModeTestMeta("openai/gpt-5.2") + + got := toolDefinitionNames(client.selectedBuiltinToolsForTurn(context.Background(), meta)) + want := []string{toolNameSessionStatus} + if !slices.Equal(got, want) { + t.Fatalf("unexpected model room builtin tools with web tools disabled: got %v want %v", got, want) } +} + +func TestSelectedBuiltinToolsForTurn_ModelRoomOmitsOnlyUnavailableSearch(t *testing.T) { + client := testBuiltinToolClient(true, false, true) + meta := modelModeTestMeta("openai/gpt-5.2") + + got := toolDefinitionNames(client.selectedBuiltinToolsForTurn(context.Background(), meta)) + want := []string{toolNameSessionStatus, toolNameWebFetch} + if !slices.Equal(got, want) { + t.Fatalf("unexpected model room builtin tools with search disabled: got %v want %v", got, want) + } +} - meta := &PortalMetadata{} +func TestSelectedBuiltinToolsForTurn_ModelRoomWithoutToolCallingGetsNoTools(t *testing.T) { + client := testBuiltinToolClient(false, true, true) + meta := modelModeTestMeta("openai/gpt-5.2") got := client.selectedBuiltinToolsForTurn(context.Background(), meta) if len(got) != 0 { - t.Fatalf("expected no builtin tools when room has no assigned agent, got %d", len(got)) + t.Fatalf("expected no builtin tools when model does not support tool calling, got %d", len(got)) } } From 572fd21a6592bba7b9aadcd29727381068936278 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 01:47:16 +0200 Subject: [PATCH 16/23] Drop thinking blocks and stop follow-up continuations Remove inline follow-up continuation logic from the agent loop and Responses adapter so turns finalize immediately instead of reopening for queued edits. Add filtering to drop PromptBlockThinking from history and introduce PromptMessage.VisibleText() (and internal text(includeThinking)) so visible-only text is used when constructing outbound inputs. Replay/history APIs now support excluding a specific message ID and prepareInboundPromptContext appends replayed history and builds the system prompt accordingly. Tests added/updated for config example sync, prompt history/visibility, and message-to-responses behavior. Also update config.example.yaml with the revised network/tools/agents schema and related defaults. --- bridges/ai/agent_loop_test.go | 32 +- bridges/ai/canonical_prompt_messages.go | 2 + bridges/ai/client.go | 9 +- bridges/ai/config_example_sync_test.go | 52 +++ bridges/ai/messages.go | 22 +- bridges/ai/messages_responses_input_test.go | 22 ++ bridges/ai/prompt_builder.go | 8 +- bridges/ai/prompt_context_local.go | 4 +- bridges/ai/prompt_history_test.go | 32 ++ bridges/ai/streaming_executor.go | 24 +- bridges/ai/streaming_responses_api.go | 12 +- config.example.yaml | 360 ++++++++------------ 12 files changed, 282 insertions(+), 297 deletions(-) create mode 100644 bridges/ai/config_example_sync_test.go create mode 100644 bridges/ai/prompt_history_test.go diff --git a/bridges/ai/agent_loop_test.go b/bridges/ai/agent_loop_test.go index 0dd3afaf..cd1e6412 100644 --- a/bridges/ai/agent_loop_test.go +++ b/bridges/ai/agent_loop_test.go @@ -11,9 +11,7 @@ import ( type fakeAgentLoopProvider struct { track bool results []fakeAgentLoopResult - followUps map[int][]PromptMessage finalizeCalls int - continueCalls int roundsObserved []int } @@ -40,19 +38,6 @@ func (f *fakeAgentLoopProvider) FinalizeAgentLoop(context.Context) { f.finalizeCalls++ } -func (f *fakeAgentLoopProvider) GetFollowUpMessages(_ context.Context) []PromptMessage { - if len(f.roundsObserved) == 0 { - return nil - } - return f.followUps[f.roundsObserved[len(f.roundsObserved)-1]] -} - -func (f *fakeAgentLoopProvider) ContinueAgentLoop(messages []PromptMessage) { - if len(messages) > 0 { - f.continueCalls++ - } -} - func TestExecuteAgentLoopRoundsFinalizesOnTerminalTurn(t *testing.T) { provider := &fakeAgentLoopProvider{ results: []fakeAgentLoopResult{ @@ -125,20 +110,10 @@ func TestExecuteAgentLoopRoundsStopsOnContextLengthWithFinalize(t *testing.T) { } } -func TestExecuteAgentLoopRoundsContinuesForFollowUpMessages(t *testing.T) { +func TestExecuteAgentLoopRoundsDoesNotInlineFollowUpMessages(t *testing.T) { provider := &fakeAgentLoopProvider{ results: []fakeAgentLoopResult{ {continueLoop: false}, - {continueLoop: false}, - }, - followUps: map[int][]PromptMessage{ - 0: {{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: "follow up", - }}, - }}, }, } @@ -152,13 +127,10 @@ func TestExecuteAgentLoopRoundsContinuesForFollowUpMessages(t *testing.T) { if err != nil { t.Fatalf("expected no error, got %v", err) } - if provider.continueCalls != 1 { - t.Fatalf("expected one follow-up continuation, got %d", provider.continueCalls) - } if provider.finalizeCalls != 1 { t.Fatalf("expected finalize once, got %d", provider.finalizeCalls) } - if len(provider.roundsObserved) != 2 || provider.roundsObserved[0] != 0 || provider.roundsObserved[1] != 1 { + if len(provider.roundsObserved) != 1 || provider.roundsObserved[0] != 0 { t.Fatalf("unexpected rounds observed: %#v", provider.roundsObserved) } } diff --git a/bridges/ai/canonical_prompt_messages.go b/bridges/ai/canonical_prompt_messages.go index e263627f..4ff88b83 100644 --- a/bridges/ai/canonical_prompt_messages.go +++ b/bridges/ai/canonical_prompt_messages.go @@ -41,6 +41,8 @@ func filterPromptBlocksForHistory(blocks []PromptBlock, injectImages bool) []Pro if injectImages { filtered = append(filtered, block) } + case PromptBlockThinking: + continue default: filtered = append(filtered, block) } diff --git a/bridges/ai/client.go b/bridges/ai/client.go index e53ee8ed..d408ae47 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -1743,10 +1743,17 @@ func (oc *AIClient) prepareInboundPromptContext( userText string, eventID id.EventID, ) (inboundPromptResult, error) { - promptContext, err := oc.buildBaseContext(ctx, portal, meta) + promptContext := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, true), + } + historyMessages, err := oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{ + mode: historyReplayNormal, + excludeMessageID: networkid.MessageID(eventID), + }) if err != nil { return inboundPromptResult{}, err } + promptContext.Messages = append(promptContext.Messages, historyMessages...) inboundCtx := oc.resolvePromptInboundContext(ctx, portal, userText, eventID) AppendPromptText(&promptContext.SystemPrompt, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) diff --git a/bridges/ai/config_example_sync_test.go b/bridges/ai/config_example_sync_test.go new file mode 100644 index 00000000..1d275655 --- /dev/null +++ b/bridges/ai/config_example_sync_test.go @@ -0,0 +1,52 @@ +package ai + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestMainConfigExampleNetworkBlockMatchesEmbeddedExample(t *testing.T) { + mainConfigPath := filepath.Join("..", "..", "config.example.yaml") + mainData, err := os.ReadFile(mainConfigPath) + if err != nil { + t.Fatalf("read %s: %v", mainConfigPath, err) + } + + var mainDoc map[string]any + if err := yaml.Unmarshal(mainData, &mainDoc); err != nil { + t.Fatalf("unmarshal %s: %v", mainConfigPath, err) + } + + networkRaw, ok := mainDoc["network"] + if !ok { + t.Fatalf("%s is missing top-level network block", mainConfigPath) + } + network, ok := networkRaw.(map[string]any) + if !ok { + t.Fatalf("%s network block has unexpected type %T", mainConfigPath, networkRaw) + } + + embeddedPath := "integrations_example-config.yaml" + embeddedData, err := os.ReadFile(embeddedPath) + if err != nil { + t.Fatalf("read %s: %v", embeddedPath, err) + } + + var embeddedDoc map[string]any + if err := yaml.Unmarshal(embeddedData, &embeddedDoc); err != nil { + t.Fatalf("unmarshal %s: %v", embeddedPath, err) + } + + if reflect.DeepEqual(network, embeddedDoc) { + return + } + + gotJSON, _ := json.MarshalIndent(network, "", " ") + wantJSON, _ := json.MarshalIndent(embeddedDoc, "", " ") + t.Fatalf("config.example.yaml network block drifted from %s\n--- got ---\n%s\n--- want ---\n%s", embeddedPath, gotJSON, wantJSON) +} diff --git a/bridges/ai/messages.go b/bridges/ai/messages.go index fd165079..66a0cb54 100644 --- a/bridges/ai/messages.go +++ b/bridges/ai/messages.go @@ -41,11 +41,21 @@ type PromptMessage struct { IsError bool } -func (m PromptMessage) Text() string { +func (m PromptMessage) text(includeThinking bool) string { var sb strings.Builder for _, block := range m.Blocks { switch block.Type { - case PromptBlockText, PromptBlockThinking: + case PromptBlockText: + if block.Text != "" { + if sb.Len() > 0 { + sb.WriteByte('\n') + } + sb.WriteString(block.Text) + } + case PromptBlockThinking: + if !includeThinking || block.Text == "" { + continue + } if block.Text != "" { if sb.Len() > 0 { sb.WriteByte('\n') @@ -57,6 +67,14 @@ func (m PromptMessage) Text() string { return sb.String() } +func (m PromptMessage) Text() string { + return m.text(true) +} + +func (m PromptMessage) VisibleText() string { + return m.text(false) +} + // PromptContext is the bridge-local prompt envelope used throughout bridges/ai. type PromptContext struct { SystemPrompt string diff --git a/bridges/ai/messages_responses_input_test.go b/bridges/ai/messages_responses_input_test.go index 715889ea..081bf01d 100644 --- a/bridges/ai/messages_responses_input_test.go +++ b/bridges/ai/messages_responses_input_test.go @@ -49,3 +49,25 @@ func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { t.Fatalf("expected text and image parts (got text=%v image=%v)", foundText, foundImage) } } + +func TestPromptContextToResponsesInput_AssistantOmitsThinkingBlocks(t *testing.T) { + input := promptContextToResponsesInput(PromptContext{ + Messages: []PromptMessage{{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{ + {Type: PromptBlockThinking, Text: "internal analysis"}, + {Type: PromptBlockText, Text: "visible reply"}, + }, + }}, + }) + if len(input) != 1 { + t.Fatalf("expected 1 input item, got %d", len(input)) + } + item := input[0].OfMessage + if item == nil { + t.Fatalf("expected assistant message input") + } + if got := item.Content.OfString.Value; got != "visible reply" { + t.Fatalf("expected only visible reply text, got %q", got) + } +} diff --git a/bridges/ai/prompt_builder.go b/bridges/ai/prompt_builder.go index f0423a44..5a4d1774 100644 --- a/bridges/ai/prompt_builder.go +++ b/bridges/ai/prompt_builder.go @@ -22,8 +22,9 @@ const ( ) type historyReplayOptions struct { - mode historyReplayMode - targetMessageID networkid.MessageID + mode historyReplayMode + targetMessageID networkid.MessageID + excludeMessageID networkid.MessageID } type currentTurnTextOptions struct { @@ -110,6 +111,9 @@ func (oc *AIClient) replayHistoryMessages( candidates := make([]replayCandidate, 0, len(hr.rows)) for _, row := range hr.rows { + if opts.excludeMessageID != "" && row.ID == opts.excludeMessageID { + continue + } msgMeta := messageMeta(row) if opts.mode == historyReplayRewrite && row.ID == opts.targetMessageID { candidates = append(candidates, replayCandidate{row: row, meta: msgMeta}) diff --git a/bridges/ai/prompt_context_local.go b/bridges/ai/prompt_context_local.go index 08efd8da..b8a2f550 100644 --- a/bridges/ai/prompt_context_local.go +++ b/bridges/ai/prompt_context_local.go @@ -92,7 +92,7 @@ func promptMessageToResponsesInputs(msg PromptMessage) responses.ResponseInputPa }} case PromptRoleAssistant: var result responses.ResponseInputParam - text := strings.TrimSpace(msg.Text()) + text := strings.TrimSpace(msg.VisibleText()) if text != "" { result = append(result, responses.ResponseInputItemUnionParam{ OfMessage: &responses.EasyInputMessageParam{ @@ -191,7 +191,7 @@ func promptAssistantToChatMessage(msg PromptMessage) *openai.ChatCompletionAssis var toolCalls []openai.ChatCompletionMessageToolCallUnionParam for _, block := range msg.Blocks { switch block.Type { - case PromptBlockText, PromptBlockThinking: + case PromptBlockText: text := strings.TrimSpace(block.Text) if text == "" { continue diff --git a/bridges/ai/prompt_history_test.go b/bridges/ai/prompt_history_test.go new file mode 100644 index 00000000..372c591a --- /dev/null +++ b/bridges/ai/prompt_history_test.go @@ -0,0 +1,32 @@ +package ai + +import "testing" + +func TestFilterPromptBlocksForHistoryDropsThinking(t *testing.T) { + filtered := filterPromptBlocksForHistory([]PromptBlock{ + {Type: PromptBlockThinking, Text: "internal analysis"}, + {Type: PromptBlockText, Text: "visible reply"}, + }, false) + if len(filtered) != 1 { + t.Fatalf("expected 1 block after filtering, got %d", len(filtered)) + } + if filtered[0].Type != PromptBlockText || filtered[0].Text != "visible reply" { + t.Fatalf("unexpected filtered blocks: %#v", filtered) + } +} + +func TestPromptMessageVisibleTextOmitsThinking(t *testing.T) { + msg := PromptMessage{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{ + {Type: PromptBlockThinking, Text: "internal analysis"}, + {Type: PromptBlockText, Text: "visible reply"}, + }, + } + if got := msg.VisibleText(); got != "visible reply" { + t.Fatalf("expected visible text only, got %q", got) + } + if got := msg.Text(); got != "internal analysis\nvisible reply" { + t.Fatalf("expected full text to retain thinking, got %q", got) + } +} diff --git a/bridges/ai/streaming_executor.go b/bridges/ai/streaming_executor.go index 22aaa101..303d2956 100644 --- a/bridges/ai/streaming_executor.go +++ b/bridges/ai/streaming_executor.go @@ -13,8 +13,6 @@ import ( type agentLoopProvider interface { TrackRoomRunStreaming() bool RunAgentTurn(ctx context.Context, evt *event.Event, round int) (continueLoop bool, cle *ContextLengthError, err error) - GetFollowUpMessages(ctx context.Context) []PromptMessage - ContinueAgentLoop(messages []PromptMessage) FinalizeAgentLoop(ctx context.Context) } @@ -51,20 +49,6 @@ func newAgentLoopProviderBase( } } -func (a *agentLoopProviderBase) GetFollowUpMessages(context.Context) []PromptMessage { - if a == nil || a.oc == nil || a.state == nil { - return nil - } - return a.oc.getFollowUpMessages(a.state.roomID) -} - -func (a *agentLoopProviderBase) ContinueAgentLoop(messages []PromptMessage) { - if a == nil || len(messages) == 0 { - return - } - a.prompt.Messages = append(a.prompt.Messages, messages...) -} - func (oc *AIClient) runAgentLoop( ctx context.Context, log zerolog.Logger, @@ -105,12 +89,8 @@ func executeAgentLoopRounds( continue } - followUpMessages := provider.GetFollowUpMessages(ctx) - if len(followUpMessages) > 0 { - provider.ContinueAgentLoop(followUpMessages) - continue - } - + // Queued user messages are dispatched after room release via processPendingQueue. + // Finalize this turn immediately so later prompts cannot reopen it with more edits. finalizeAgentLoopExit(ctx, provider, false) return true, nil, nil } diff --git a/bridges/ai/streaming_responses_api.go b/bridges/ai/streaming_responses_api.go index ab5e25bb..20b682d8 100644 --- a/bridges/ai/streaming_responses_api.go +++ b/bridges/ai/streaming_responses_api.go @@ -27,7 +27,6 @@ type responsesTurnAdapter struct { agentLoopProviderBase params responses.ResponseNewParams initialized bool - hasFollowUp bool rsc *responseStreamContext } @@ -92,7 +91,6 @@ func (a *responsesTurnAdapter) startContinuationRound(ctx context.Context) (*sse if stream == nil { return nil, continuationParams, errors.New("continuation streaming not available") } - a.hasFollowUp = false state.clearContinuationState() return stream, continuationParams, nil } @@ -117,7 +115,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( return false, nil, &PreDeltaError{Err: err} } } else { - if len(state.pendingFunctionOutputs) == 0 && len(state.pendingMcpApprovals) == 0 && !a.hasFollowUp { + if len(state.pendingFunctionOutputs) == 0 && len(state.pendingMcpApprovals) == 0 { return false, nil, nil } if round > maxAgentLoopToolTurns { @@ -178,14 +176,6 @@ func (a *responsesTurnAdapter) FinalizeAgentLoop(ctx context.Context) { a.oc.finalizeResponsesStream(ctx, a.log, a.portal, a.state, a.meta) } -func (a *responsesTurnAdapter) ContinueAgentLoop(messages []PromptMessage) { - if len(messages) == 0 { - return - } - a.prompt.Messages = append(a.prompt.Messages, messages...) - a.hasFollowUp = true -} - // processResponseStreamEvent handles a single Responses API stream event. // Returns done=true when the caller's loop should break (error/fatal), along with // any context-length error or general error. The caller is responsible for diff --git a/config.example.yaml b/config.example.yaml index e5c6a761..6fe29f82 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -41,82 +41,104 @@ encryption: # AI Chats-specific options (shared with the embedded example in bridges/ai/integrations_config.go) network: - # Beeper Cloud credentials for automatic login (optional) - beeper: - user_mxid: "" # Owning Matrix user for the built-in Beeper Cloud login. - base_url: "" # Optional. If empty, login uses selected Beeper domain. - token: "" # Beeper Matrix access token + # Connector-specific configuration lives under the `network:` section of the + # main config file. - # Per-provider default models - providers: - beeper: - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - openai: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://api.openai.com/v1 - base_url: "https://api.openai.com/v1" - default_model: "openai/gpt-5.2" - openrouter: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://openrouter.ai/api/v1 - base_url: "https://openrouter.ai/api/v1" - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" + beeper: + user_mxid: "" + base_url: "" + token: "" -# Optional model catalog seeding (OpenClaw-style). -# models: -# mode: "merge" # merge | replace -# providers: -# openai: -# models: -# - id: "gpt-5.2" -# name: "GPT-5.2" -# reasoning: true -# input: ["text", "image"] -# context_window: 128000 -# max_tokens: 8192 + models: + providers: + openai: + api_key: "" + base_url: "https://api.openai.com/v1" + models: [] + openrouter: + api_key: "" + base_url: "https://openrouter.ai/api/v1" + models: [] + magic_proxy: + api_key: "" + base_url: "" + models: [] -# Global settings -default_system_prompt: | - You are a helpful, concise assistant. + default_system_prompt: | + You are a helpful, concise assistant. Ask clarifying questions when needed. Follow the user's intent and be accurate. model_cache_duration: 6h - # External tool providers (search + fetch). Proxy is optional. - tools: - search: - provider: "exa" - fallbacks: [] - exa: - api_key: "" - base_url: "https://api.exa.ai" - type: "auto" - num_results: 5 - include_text: false - text_max_chars: 500 - fetch: - provider: "exa" - fallbacks: ["direct"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - include_text: true - text_max_chars: 5000 - direct: - enabled: true - timeout_seconds: 30 - max_chars: 50000 - max_redirects: 3 + messages: + direct_chat: + history_limit: 20 + group_chat: + history_limit: 50 + queue: + mode: "collect" + debounce_ms: 1000 + cap: 20 + drop: "summarize" + + commands: + owner_allow_from: [] + + tool_approvals: + enabled: true + ttl_seconds: 600 + require_for_mcp: true + require_for_tools: ["message", "cron", "gravatar_set", "create_agent", "fork_agent", "edit_agent", "delete_agent", "modify_room", "sessions_send", "sessions_spawn", "run_internal_command"] + + channels: + matrix: + reply_to_mode: "first" + + session: + scope: "per-sender" + main_key: "main" - # Media understanding/transcription (OpenClaw-style). + tools: + web: + search: + provider: "exa" + fallbacks: [] + exa: + api_key: "" + base_url: "https://api.exa.ai" + type: "auto" + num_results: 5 + include_text: false + text_max_chars: 500 + highlights: true + fetch: + provider: "direct" + fallbacks: ["exa"] + exa: + api_key: "" + base_url: "https://api.exa.ai" + include_text: true + text_max_chars: 5000 + direct: + enabled: true + timeout_seconds: 30 + max_chars: 50000 + max_redirects: 3 + links: + enabled: true + max_urls_inbound: 3 + max_urls_outbound: 5 + fetch_timeout: 10s + max_content_chars: 500 + max_page_bytes: 10485760 + max_image_bytes: 5242880 + cache_ttl: 1h + mcp: + enable_stdio: false + vfs: + apply_patch: + enabled: false + allow_models: [] media: concurrency: 2 image: @@ -132,7 +154,6 @@ default_system_prompt: | enabled: true prompt: "Transcribe the audio." language: "" - # CLI transcription auto-detection (whisper/whisper.cpp) is not implemented yet. max_bytes: 20971520 timeout_seconds: 60 models: @@ -147,161 +168,46 @@ default_system_prompt: | - provider: "openrouter" model: "google/gemini-3-flash-preview" - # Memory search configuration (OpenClaw-style). - # Indexes MEMORY.md + memory/*.md stored in the bridge DB. - # Per-agent overrides can be set via agent definitions. - # Current runtime behavior is lexical-only. - memory_search: - enabled: true - sources: ["memory"] - extra_paths: [] - local: - model_path: "" - model_cache_dir: "" - base_url: "" - api_key: "" - store: - driver: "sqlite" - path: "" - chunking: - tokens: 400 - overlap: 80 - sync: - on_session_start: true - on_search: true - watch: true - watch_debounce_ms: 1500 - interval_minutes: 0 - sessions: - delta_bytes: 100000 - delta_messages: 50 - query: - max_results: 6 - min_score: 0.35 - hybrid: - candidate_multiplier: 4 - cache: - enabled: true - max_entries: -1 # Unlimited when cache.enabled is true; experimental.session_memory does not change cache size semantics. - experimental: - session_memory: false - - # Tool policy (OpenClaw-style). Controls allow/deny lists and profiles. - # tool_policy: - # profile: "full" - # # group:openclaw is the strict OpenClaw native tool set. - # # group:ai-bridge includes ai-bridge-only extras (beeper_docs, gravatar_*, tts, image_generate, calculator, etc). - # allow: ["group:openclaw", "group:ai-bridge"] - # deny: [] - # subagents: - # tools: - # deny: ["sessions_list", "sessions_history", "sessions_send"] - - # Agent defaults (OpenClaw-style). - # agents: - # defaults: - # subagents: - # model: "anthropic/claude-sonnet-4.5" - # allow_agents: ["*"] - # typing_mode: "instant" # never|instant|thinking|message (message ignores NO_REPLY) - # typing_interval_seconds: 6 # refresh cadence, not start time (heartbeats never show typing) - - # Context pruning configuration (OpenClaw-style). - # Reduces token usage by intelligently truncating old tool results. - pruning: - # Enable proactive context pruning - enabled: true - - # Ratio of context window usage that triggers soft trimming (0.0-1.0) - # At 30% usage, large tool results start getting truncated - soft_trim_ratio: 0.3 - - # Ratio of context window usage that triggers hard clearing (0.0-1.0) - # At 50% usage, old tool results are replaced with placeholder - hard_clear_ratio: 0.5 - - # Number of recent assistant messages to protect from pruning - keep_last_assistants: 3 - - # Minimum total chars in prunable tool results before hard clear kicks in - min_prunable_chars: 50000 - - # Tool results larger than this are candidates for soft trimming - soft_trim_max_chars: 4000 - - # When soft trimming, keep this many chars from the start - soft_trim_head_chars: 1500 - - # When soft trimming, keep this many chars from the end - soft_trim_tail_chars: 1500 - - # Enable/disable hard clear phase - hard_clear_enabled: true - - # Placeholder text for hard-cleared tool results - hard_clear_placeholder: "[Old tool result content cleared]" - - # Tool patterns to allow/deny pruning (supports wildcards: exec*, *_search) - # Empty means all tools are prunable unless denied - # tools_allow: [] - # tools_deny: [] - - # --- LLM-based summarization (compaction) --- - # When enabled, uses an LLM to generate intelligent summaries of compacted - # content instead of just using placeholder text. This preserves context better. - - # Enable LLM summarization (default: true when pruning is enabled) - summarization_enabled: true - - # Model to use for generating summaries (default: fast model) - summarization_model: "openai/gpt-5.2" - - # Maximum tokens for generated summaries - max_summary_tokens: 500 - - # Maximum ratio of context that history can consume (0.0-1.0) - # When exceeded, oldest messages are summarized to fit budget - max_history_share: 0.5 - - # Token budget reserved for compaction output - reserve_tokens: 20000 - - # Additional instructions for the summarization model - # custom_instructions: "Focus on preserving code decisions and TODOs" - - # Identifier preservation policy for summaries: - # - strict (default): preserve opaque identifiers exactly - # - off: no special identifier-preservation instruction - # - custom: use identifier_instructions below - identifier_policy: "strict" - # identifier_instructions: "Keep ticket IDs, hashes, and hostnames unchanged." - - # Optional pre-compaction overflow flush turn. - # Disabled by default to keep compaction independent from memory workflows. - overflow_flush: - enabled: false - soft_threshold_tokens: 4000 - prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." - system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - - # Link preview configuration. - # Automatically fetches metadata for URLs in messages to provide context to the AI - # and generate rich previews in outgoing AI responses. - link_previews: - # Enable link preview functionality (default: true) - enabled: true - - # Maximum number of URLs to fetch from user messages for AI context (default: 3) - max_urls_inbound: 3 - - # Maximum number of URLs to preview in AI responses (default: 5) - max_urls_outbound: 5 - - # Timeout for fetching each URL (default: 10s) - fetch_timeout: 10s - - # Maximum characters from description to include in context (default: 500) - max_content_chars: 500 - - # Maximum page size to download in bytes (default: 10MB) - max_page_bytes: 10485760 + agents: + defaults: + model: + primary: "" + fallbacks: [] + image_model: + primary: "" + fallbacks: [] + image_generation_model: + primary: "" + fallbacks: [] + pdf_model: + primary: "" + fallbacks: [] + pdf_engine: "mistral-ocr" + compaction: + mode: "cache-ttl" + ttl: "1h" + enabled: true + soft_trim_ratio: 0.3 + hard_clear_ratio: 0.5 + keep_last_assistants: 3 + min_prunable_chars: 50000 + soft_trim_max_chars: 4000 + soft_trim_head_chars: 1500 + soft_trim_tail_chars: 1500 + hard_clear_enabled: true + hard_clear_placeholder: "[Old tool result content cleared]" + summarization_enabled: true + summarization_model: "openai/gpt-5.2" + max_summary_tokens: 500 + compaction_mode: "safeguard" + keep_recent_tokens: 20000 + max_history_share: 0.5 + reserve_tokens: 20000 + reserve_tokens_floor: 20000 + identifier_policy: "strict" + post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." + overflow_flush: + enabled: true + soft_threshold_tokens: 4000 + prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." + system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." From 07c8555d98766473da928a33f4eecf98586c09bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 02:02:18 +0200 Subject: [PATCH 17/23] Honor pending queue; disable MagicProxy image API Do not start a room run when there is pending queue work: add roomHasPendingQueueWork, factor it into queue decision logic, clear roomBusy on interrupt, and only acquire the room when not busy. Add a test to verify new queuing-behind-backlog behavior. For image generation, treat Magic Proxy as not exposing Gemini/OpenRouter image endpoints (return false) and update tests to expect routing to OpenAI and gemini unavailability accordingly. --- bridges/ai/client.go | 6 ++- bridges/ai/image_generation_tool.go | 19 ++----- .../image_generation_tool_magic_proxy_test.go | 25 ++++------ bridges/ai/pending_queue.go | 13 +++++ bridges/ai/queue_status_test.go | 50 +++++++++++++++++++ 5 files changed, 83 insertions(+), 30 deletions(-) diff --git a/bridges/ai/client.go b/bridges/ai/client.go index d408ae47..842dcac5 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -637,12 +637,14 @@ func (oc *AIClient) dispatchOrQueueCore( shouldSteer := behavior.Steer shouldFollowup := behavior.Followup hasDBMessage := userMessage != nil - queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, oc.roomHasActiveRun(roomID), false) + roomBusy := oc.roomHasActiveRun(roomID) || oc.roomHasPendingQueueWork(roomID) + queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, roomBusy, false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(roomID) oc.clearPendingQueue(roomID) + roomBusy = false } - if oc.acquireRoom(roomID) { + if !roomBusy && oc.acquireRoom(roomID) { oc.stopQueueTyping(roomID) if hasDBMessage { oc.saveUserMessage(ctx, evt, userMessage) diff --git a/bridges/ai/image_generation_tool.go b/bridges/ai/image_generation_tool.go index c72423d9..70314fe3 100644 --- a/bridges/ai/image_generation_tool.go +++ b/bridges/ai/image_generation_tool.go @@ -255,14 +255,8 @@ func supportsGeminiImageGen(btc *BridgeToolContext) bool { } switch loginMeta.Provider { case ProviderMagicProxy: - if btc.Client.connector != nil { - services := btc.Client.connector.resolveServiceConfig(loginMeta) - if svc, ok := services[serviceGemini]; ok { - return strings.TrimSpace(svc.BaseURL) != "" && strings.TrimSpace(svc.APIKey) != "" - } - } - base := normalizeProxyBaseURL(loginCredentialBaseURL(loginMeta)) - return base != "" && loginCredentialAPIKey(loginMeta) != "" + // Magic Proxy does not expose the Gemini image generation endpoint. + return false default: return false } @@ -592,12 +586,9 @@ func resolveOpenRouterImageGenEndpoint(btc *BridgeToolContext) (baseURL string, // Provider-specific per-login endpoints. switch meta.Provider { case ProviderMagicProxy: - base := normalizeProxyBaseURL(loginCredentialBaseURL(meta)) - key := trim(loginCredentialAPIKey(meta)) - if base == "" || key == "" { - return "", "", false - } - return joinProxyPath(base, "/openrouter/v1"), key, true + // Magic Proxy does not expose the OpenRouter images endpoint; use the + // verified OpenAI images route instead. + return "", "", false case ProviderOpenRouter: if conn == nil { return "", "", false diff --git a/bridges/ai/image_generation_tool_magic_proxy_test.go b/bridges/ai/image_generation_tool_magic_proxy_test.go index e425df0a..2c1d3864 100644 --- a/bridges/ai/image_generation_tool_magic_proxy_test.go +++ b/bridges/ai/image_generation_tool_magic_proxy_test.go @@ -2,7 +2,7 @@ package ai import "testing" -func TestResolveImageGenProviderMagicProxyPrefersOpenRouterForSimplePrompts(t *testing.T) { +func TestResolveImageGenProviderMagicProxyPrefersOpenAIForSimplePrompts(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, Credentials: &LoginCredentials{ @@ -19,12 +19,12 @@ func TestResolveImageGenProviderMagicProxyPrefersOpenRouterForSimplePrompts(t *t if err != nil { t.Fatalf("resolveImageGenProvider returned error: %v", err) } - if got != imageGenProviderOpenRouter { - t.Fatalf("expected provider %q, got %q", imageGenProviderOpenRouter, got) + if got != imageGenProviderOpenAI { + t.Fatalf("expected provider %q, got %q", imageGenProviderOpenAI, got) } } -func TestResolveImageGenProviderMagicProxyStillPrefersOpenRouterWhenCountIsGreaterThanOne(t *testing.T) { +func TestResolveImageGenProviderMagicProxyStillPrefersOpenAIWhenCountIsGreaterThanOne(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, Credentials: &LoginCredentials{ @@ -41,12 +41,12 @@ func TestResolveImageGenProviderMagicProxyStillPrefersOpenRouterWhenCountIsGreat if err != nil { t.Fatalf("resolveImageGenProvider returned error: %v", err) } - if got != imageGenProviderOpenRouter { - t.Fatalf("expected provider %q, got %q", imageGenProviderOpenRouter, got) + if got != imageGenProviderOpenAI { + t.Fatalf("expected provider %q, got %q", imageGenProviderOpenAI, got) } } -func TestResolveImageGenProviderMagicProxyProviderOpenAIStillRoutesToOpenRouter(t *testing.T) { +func TestResolveImageGenProviderMagicProxyProviderOpenAIUsesOpenAI(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, Credentials: &LoginCredentials{ @@ -69,7 +69,7 @@ func TestResolveImageGenProviderMagicProxyProviderOpenAIStillRoutesToOpenRouter( } } -func TestResolveImageGenProviderMagicProxyProviderGeminiUsesGemini(t *testing.T) { +func TestResolveImageGenProviderMagicProxyProviderGeminiIsUnavailable(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, Credentials: &LoginCredentials{ @@ -79,16 +79,13 @@ func TestResolveImageGenProviderMagicProxyProviderGeminiUsesGemini(t *testing.T) } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) - got, err := resolveImageGenProvider(imageGenRequest{ + _, err := resolveImageGenProvider(imageGenRequest{ Provider: "gemini", Prompt: "cat", Count: 1, }, btc) - if err != nil { - t.Fatalf("resolveImageGenProvider returned error: %v", err) - } - if got != imageGenProviderGemini { - t.Fatalf("expected provider %q, got %q", imageGenProviderGemini, got) + if err == nil { + t.Fatal("expected gemini image generation to be unavailable for magic proxy") } } diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index d2b40089..3332f3d8 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -176,6 +176,19 @@ func (oc *AIClient) getQueueSnapshot(roomID id.RoomID) *pendingQueue { return &clone } +func (oc *AIClient) roomHasPendingQueueWork(roomID id.RoomID) bool { + if oc == nil || roomID == "" { + return false + } + oc.pendingQueuesMu.Lock() + defer oc.pendingQueuesMu.Unlock() + queue := oc.pendingQueues[roomID] + if queue == nil { + return false + } + return queue.draining || len(queue.items) > 0 || queue.droppedCount > 0 +} + func (oc *AIClient) consumeQueueSummary(roomID id.RoomID, noun string) string { oc.pendingQueuesMu.Lock() defer oc.pendingQueuesMu.Unlock() diff --git a/bridges/ai/queue_status_test.go b/bridges/ai/queue_status_test.go index 8be93e96..b30b6092 100644 --- a/bridges/ai/queue_status_test.go +++ b/bridges/ai/queue_status_test.go @@ -140,3 +140,53 @@ func TestDispatchOrQueueQueueAcceptReturnsPending(t *testing.T) { t.Fatalf("expected queue length 1 after accept, got %d", got) } } + +func TestDispatchOrQueueQueuesBehindExistingPendingWork(t *testing.T) { + roomID := id.RoomID("!room:example.com") + oc := &AIClient{ + activeRooms: map[id.RoomID]bool{}, + pendingQueues: map[id.RoomID]*pendingQueue{}, + } + oc.pendingQueues[roomID] = &pendingQueue{ + items: []pendingQueueItem{ + { + pending: pendingMessage{Type: pendingTypeText, MessageBody: "older"}, + }, + }, + cap: 10, + dropPolicy: airuntime.QueueDropOld, + } + + evt := &event.Event{ID: id.EventID("$new")} + portal := &bridgev2.Portal{Portal: &database.Portal{}} + portal.MXID = roomID + queueItem := pendingQueueItem{ + pending: pendingMessage{Type: pendingTypeText, MessageBody: "new"}, + messageID: string(evt.ID), + } + + _, isPending := oc.dispatchOrQueue( + context.Background(), + evt, + portal, + nil, + nil, + queueItem, + airuntime.QueueSettings{Mode: airuntime.QueueModeCollect, Cap: 10, DropPolicy: airuntime.QueueDropOld}, + PromptContext{}, + ) + + if !isPending { + t.Fatalf("expected pending=true when older queued work exists") + } + queue := oc.pendingQueues[roomID] + if queue == nil { + t.Fatalf("expected pending queue to exist") + } + if got := len(queue.items); got != 2 { + t.Fatalf("expected queue length 2 after enqueue behind backlog, got %d", got) + } + if oc.activeRooms[roomID] { + t.Fatalf("expected room to remain unacquired while backlog exists") + } +} From 7966e20545ab494b5417672422bce7119b2c1265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 02:56:33 +0200 Subject: [PATCH 18/23] Regenerate AI model manifest and allowlist Regenerated bridges/ai/beeper_models_generated.go with an updated timestamp. The manifest was refreshed: many models were removed, added or renamed, and several model capabilities (API, Supports*, ContextWindow, MaxOutputTokens, AvailableTools) were adjusted. Aliases (e.g. beeper/fast, beeper/reasoning, beeper/smart) were updated to new targets. Updated the allowlist expectations in bridges/ai/beeper_models_manifest_test.go to match the new manifest. --- bridges/ai/beeper_models_generated.go | 903 ++-------------- bridges/ai/beeper_models_manifest_test.go | 97 +- bridges/ai/chat_login_redirect_test.go | 18 +- bridges/ai/client.go | 17 +- bridges/ai/client_capabilities_test.go | 12 +- bridges/ai/handlematrix.go | 10 +- bridges/ai/handlematrix_edit_test.go | 61 ++ bridges/ai/image_generation_tool.go | 33 +- .../image_generation_tool_magic_proxy_test.go | 44 + bridges/ai/model_catalog_test.go | 37 + cmd/generate-models/main.go | 133 +-- generate-models.sh | 12 +- pkg/agents/presets.go | 4 +- pkg/ai/beeper_models.json | 979 +++--------------- 14 files changed, 545 insertions(+), 1815 deletions(-) create mode 100644 bridges/ai/handlematrix_edit_test.go diff --git a/bridges/ai/beeper_models_generated.go b/bridges/ai/beeper_models_generated.go index 10f9f707..bf5e0348 100644 --- a/bridges/ai/beeper_models_generated.go +++ b/bridges/ai/beeper_models_generated.go @@ -1,5 +1,5 @@ // Code generated by generate-models. DO NOT EDIT. -// Generated at: 2026-03-08T11:58:59Z +// Generated at: 2026-03-30T00:25:29Z package ai @@ -27,40 +27,6 @@ var ModelManifest = struct { MaxOutputTokens: 64000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "anthropic/claude-opus-4.1": { - ID: "anthropic/claude-opus-4.1", - Name: "Claude 4.1 Opus", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 32000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "anthropic/claude-opus-4.5": { - ID: "anthropic/claude-opus-4.5", - Name: "Claude Opus 4.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 64000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, "anthropic/claude-opus-4.6": { ID: "anthropic/claude-opus-4.6", Name: "Claude Opus 4.6", @@ -78,40 +44,6 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "anthropic/claude-sonnet-4": { - ID: "anthropic/claude-sonnet-4", - Name: "Claude 4 Sonnet", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 64000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "anthropic/claude-sonnet-4.5": { - ID: "anthropic/claude-sonnet-4.5", - Name: "Claude Sonnet 4.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1000000, - MaxOutputTokens: 64000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, "anthropic/claude-sonnet-4.6": { ID: "anthropic/claude-sonnet-4.6", Name: "Claude Sonnet 4.6", @@ -129,57 +61,6 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "deepseek/deepseek-chat-v3-0324": { - ID: "deepseek/deepseek-chat-v3-0324", - Name: "DeepSeek v3 (0324)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 163840, - MaxOutputTokens: 163840, - AvailableTools: []string{ToolFunctionCalling}, - }, - "deepseek/deepseek-chat-v3.1": { - ID: "deepseek/deepseek-chat-v3.1", - Name: "DeepSeek v3.1", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 32768, - MaxOutputTokens: 7168, - AvailableTools: []string{ToolFunctionCalling}, - }, - "deepseek/deepseek-r1": { - ID: "deepseek/deepseek-r1", - Name: "DeepSeek R1 (Original)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 64000, - MaxOutputTokens: 16000, - AvailableTools: []string{ToolFunctionCalling}, - }, "deepseek/deepseek-r1-0528": { ID: "deepseek/deepseek-r1-0528", Name: "DeepSeek R1 (0528)", @@ -197,40 +78,6 @@ var ModelManifest = struct { MaxOutputTokens: 65536, AvailableTools: []string{ToolFunctionCalling}, }, - "deepseek/deepseek-r1-distill-qwen-32b": { - ID: "deepseek/deepseek-r1-distill-qwen-32b", - Name: "DeepSeek R1 (Qwen Distilled)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: false, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 32768, - MaxOutputTokens: 32768, - AvailableTools: []string{}, - }, - "deepseek/deepseek-v3.1-terminus": { - ID: "deepseek/deepseek-v3.1-terminus", - Name: "DeepSeek v3.1 Terminus", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 163840, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, "deepseek/deepseek-v3.2": { ID: "deepseek/deepseek-v3.2", Name: "DeepSeek v3.2", @@ -245,77 +92,9 @@ var ModelManifest = struct { SupportsVideo: false, SupportsPDF: false, ContextWindow: 163840, - MaxOutputTokens: 65536, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-2.0-flash-001": { - ID: "google/gemini-2.0-flash-001", - Name: "Gemini 2.0 Flash", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 8192, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-2.0-flash-lite-001": { - ID: "google/gemini-2.0-flash-lite-001", - Name: "Gemini 2.0 Flash Lite", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 8192, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-2.5-flash": { - ID: "google/gemini-2.5-flash", - Name: "Gemini 2.5 Flash", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 65535, + MaxOutputTokens: 0, AvailableTools: []string{ToolFunctionCalling}, }, - "google/gemini-2.5-flash-image": { - ID: "google/gemini-2.5-flash-image", - Name: "Nano Banana", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: false, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: true, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 32768, - MaxOutputTokens: 32768, - AvailableTools: []string{}, - }, "google/gemini-2.5-flash-lite": { ID: "google/gemini-2.5-flash-lite", Name: "Gemini 2.5 Flash Lite", @@ -367,74 +146,6 @@ var ModelManifest = struct { MaxOutputTokens: 65536, AvailableTools: []string{ToolFunctionCalling}, }, - "google/gemini-3-pro-image-preview": { - ID: "google/gemini-3-pro-image-preview", - Name: "Nano Banana Pro", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: false, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: true, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 65536, - MaxOutputTokens: 32768, - AvailableTools: []string{}, - }, - "google/gemini-3.1-flash-lite-preview": { - ID: "google/gemini-3.1-flash-lite-preview", - Name: "Gemini 3.1 Flash Lite", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 65536, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-3.1-pro-preview": { - ID: "google/gemini-3.1-pro-preview", - Name: "Gemini 3.1 Pro", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 65536, - AvailableTools: []string{ToolFunctionCalling}, - }, - "meta-llama/llama-3.3-70b-instruct": { - ID: "meta-llama/llama-3.3-70b-instruct", - Name: "Llama 3.3 70B", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolFunctionCalling}, - }, "meta-llama/llama-4-maverick": { ID: "meta-llama/llama-4-maverick", Name: "Llama 4 Maverick", @@ -452,283 +163,79 @@ var ModelManifest = struct { MaxOutputTokens: 16384, AvailableTools: []string{ToolFunctionCalling}, }, - "meta-llama/llama-4-scout": { - ID: "meta-llama/llama-4-scout", - Name: "Llama 4 Scout", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 327680, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolFunctionCalling}, - }, - "minimax/minimax-m2": { - ID: "minimax/minimax-m2", - Name: "MiniMax M2", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 196608, - MaxOutputTokens: 196608, - AvailableTools: []string{ToolFunctionCalling}, - }, - "minimax/minimax-m2.1": { - ID: "minimax/minimax-m2.1", - Name: "MiniMax M2.1", + "minimax/minimax-m2.7": { + ID: "minimax/minimax-m2.7", + Name: "MiniMax M2.7", Provider: "openrouter", API: "openai-completions", SupportsVision: false, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 196608, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "minimax/minimax-m2.5": { - ID: "minimax/minimax-m2.5", - Name: "MiniMax M2.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 196608, - MaxOutputTokens: 196608, - AvailableTools: []string{ToolFunctionCalling}, - }, - "moonshotai/kimi-k2": { - ID: "moonshotai/kimi-k2", - Name: "Kimi K2 (0711)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131000, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "moonshotai/kimi-k2-0905": { - ID: "moonshotai/kimi-k2-0905", - Name: "Kimi K2 (0905)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "moonshotai/kimi-k2.5": { - ID: "moonshotai/kimi-k2.5", - Name: "Kimi K2.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 262144, - MaxOutputTokens: 65535, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/gpt-4.1": { - ID: "openai/gpt-4.1", - Name: "GPT-4.1", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1047576, - MaxOutputTokens: 32768, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-4.1-mini": { - ID: "openai/gpt-4.1-mini", - Name: "GPT-4.1 Mini", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1047576, - MaxOutputTokens: 32768, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-4.1-nano": { - ID: "openai/gpt-4.1-nano", - Name: "GPT-4.1 Nano", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1047576, - MaxOutputTokens: 32768, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-4o-mini": { - ID: "openai/gpt-4o-mini", - Name: "GPT-4o-mini", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 128000, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-5": { - ID: "openai/gpt-5", - Name: "GPT-5", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-5-image": { - ID: "openai/gpt-5-image", - Name: "GPT ImageGen 1.5", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: true, + SupportsWebSearch: false, + SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 204800, + MaxOutputTokens: 131072, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5-image-mini": { - ID: "openai/gpt-5-image-mini", - Name: "GPT ImageGen", + "mistralai/devstral-2512": { + ID: "mistralai/devstral-2512", + Name: "Devstral 2", Provider: "openrouter", - API: "responses", - SupportsVision: true, + API: "openai-completions", + SupportsVision: false, SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: true, + SupportsReasoning: false, + SupportsWebSearch: false, + SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 262144, + MaxOutputTokens: 0, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5-mini": { - ID: "openai/gpt-5-mini", - Name: "GPT-5 mini", + "mistralai/mistral-small-2603": { + ID: "mistralai/mistral-small-2603", + Name: "Mistral Small 4", Provider: "openrouter", - API: "responses", + API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 262144, + MaxOutputTokens: 0, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5-nano": { - ID: "openai/gpt-5-nano", - Name: "GPT-5 nano", + "moonshotai/kimi-k2.5": { + ID: "moonshotai/kimi-k2.5", + Name: "Kimi K2.5", Provider: "openrouter", - API: "responses", + API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 262144, + MaxOutputTokens: 65535, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5.1": { - ID: "openai/gpt-5.1", - Name: "GPT-5.1", + "openai/gpt-5-mini": { + ID: "openai/gpt-5-mini", + Name: "GPT-5 mini", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -745,7 +252,7 @@ var ModelManifest = struct { ID: "openai/gpt-5.2", Name: "GPT-5.2", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -758,11 +265,11 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/gpt-5.2-pro": { - ID: "openai/gpt-5.2-pro", - Name: "GPT-5.2 Pro", + "openai/gpt-5.3-codex": { + ID: "openai/gpt-5.3-codex", + Name: "GPT-5.3 Codex", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -775,28 +282,11 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/gpt-5.3-chat": { - ID: "openai/gpt-5.3-chat", - Name: "GPT-5.3 Instant", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 128000, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, "openai/gpt-5.4": { ID: "openai/gpt-5.4", Name: "GPT-5.4", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -809,45 +299,11 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/gpt-oss-120b": { - ID: "openai/gpt-oss-120b", - Name: "GPT OSS 120B", - Provider: "openrouter", - API: "responses", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/gpt-oss-20b": { - ID: "openai/gpt-oss-20b", - Name: "GPT OSS 20B", - Provider: "openrouter", - API: "responses", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/o3": { - ID: "openai/o3", - Name: "o3", + "openai/gpt-5.4-mini": { + ID: "openai/gpt-5.4-mini", + Name: "GPT-5.4 Mini", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -856,32 +312,15 @@ var ModelManifest = struct { SupportsAudio: false, SupportsVideo: false, SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 100000, + ContextWindow: 400000, + MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/o3-mini": { - ID: "openai/o3-mini", - Name: "o3-mini", - Provider: "openrouter", - API: "responses", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 100000, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/o3-pro": { - ID: "openai/o3-pro", - Name: "o3 Pro", + "openai/o3": { + ID: "openai/o3", + Name: "o3", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -898,7 +337,7 @@ var ModelManifest = struct { ID: "openai/o4-mini", Name: "o4-mini", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -928,43 +367,9 @@ var ModelManifest = struct { MaxOutputTokens: 0, AvailableTools: []string{}, }, - "qwen/qwen3-235b-a22b": { - ID: "qwen/qwen3-235b-a22b", - Name: "Qwen 3 235B", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 8192, - AvailableTools: []string{ToolFunctionCalling}, - }, - "qwen/qwen3-32b": { - ID: "qwen/qwen3-32b", - Name: "Qwen 3 32B", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 40960, - MaxOutputTokens: 40960, - AvailableTools: []string{ToolFunctionCalling}, - }, - "qwen/qwen3-coder": { - ID: "qwen/qwen3-coder", - Name: "Qwen 3 Coder", + "qwen/qwen3-coder-next": { + ID: "qwen/qwen3-coder-next", + Name: "Qwen 3 Coder Next", Provider: "openrouter", API: "openai-completions", SupportsVision: false, @@ -976,76 +381,42 @@ var ModelManifest = struct { SupportsVideo: false, SupportsPDF: false, ContextWindow: 262144, - MaxOutputTokens: 0, + MaxOutputTokens: 65536, AvailableTools: []string{ToolFunctionCalling}, }, - "x-ai/grok-3": { - ID: "x-ai/grok-3", - Name: "Grok 3", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "x-ai/grok-3-mini": { - ID: "x-ai/grok-3-mini", - Name: "Grok 3 Mini", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "x-ai/grok-4": { - ID: "x-ai/grok-4", - Name: "Grok 4", + "qwen/qwen3.5-flash-02-23": { + ID: "qwen/qwen3.5-flash-02-23", + Name: "Qwen 3.5 Flash", Provider: "openrouter", API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, - SupportsVideo: false, + SupportsVideo: true, SupportsPDF: false, - ContextWindow: 256000, - MaxOutputTokens: 0, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + ContextWindow: 1000000, + MaxOutputTokens: 65536, + AvailableTools: []string{ToolFunctionCalling}, }, - "x-ai/grok-4-fast": { - ID: "x-ai/grok-4-fast", - Name: "Grok 4 Fast", + "qwen/qwen3.5-plus-02-15": { + ID: "qwen/qwen3.5-plus-02-15", + Name: "Qwen 3.5 Plus", Provider: "openrouter", API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, - SupportsVideo: false, + SupportsVideo: true, SupportsPDF: false, - ContextWindow: 2000000, - MaxOutputTokens: 30000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + ContextWindow: 1000000, + MaxOutputTokens: 65536, + AvailableTools: []string{ToolFunctionCalling}, }, "x-ai/grok-4.1-fast": { ID: "x-ai/grok-4.1-fast", @@ -1064,111 +435,43 @@ var ModelManifest = struct { MaxOutputTokens: 30000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "z-ai/glm-4.5": { - ID: "z-ai/glm-4.5", - Name: "GLM 4.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 98304, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.5-air": { - ID: "z-ai/glm-4.5-air", - Name: "GLM 4.5 Air", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 98304, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.5v": { - ID: "z-ai/glm-4.5v", - Name: "GLM 4.5V", + "x-ai/grok-4.20-beta": { + ID: "x-ai/grok-4.20-beta", + Name: "Grok 4.20 Beta", Provider: "openrouter", API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 65536, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.6": { - ID: "z-ai/glm-4.6", - Name: "GLM 4.6", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, + SupportsWebSearch: true, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, SupportsPDF: false, - ContextWindow: 204800, - MaxOutputTokens: 204800, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.6v": { - ID: "z-ai/glm-4.6v", - Name: "GLM 4.6V", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: true, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 131072, - AvailableTools: []string{ToolFunctionCalling}, + ContextWindow: 2000000, + MaxOutputTokens: 0, + AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "z-ai/glm-4.7": { - ID: "z-ai/glm-4.7", - Name: "GLM 4.7", + "x-ai/grok-code-fast-1": { + ID: "x-ai/grok-code-fast-1", + Name: "Grok Code Fast 1", Provider: "openrouter", API: "openai-completions", SupportsVision: false, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: false, + SupportsWebSearch: true, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, SupportsPDF: false, - ContextWindow: 202752, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, + ContextWindow: 256000, + MaxOutputTokens: 10000, + AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "z-ai/glm-5": { - ID: "z-ai/glm-5", - Name: "GLM 5", + "z-ai/glm-5-turbo": { + ID: "z-ai/glm-5-turbo", + Name: "GLM 5 Turbo", Provider: "openrouter", API: "openai-completions", SupportsVision: false, @@ -1180,14 +483,14 @@ var ModelManifest = struct { SupportsVideo: false, SupportsPDF: false, ContextWindow: 202752, - MaxOutputTokens: 0, + MaxOutputTokens: 131072, AvailableTools: []string{ToolFunctionCalling}, }, }, Aliases: map[string]string{ "beeper/default": "anthropic/claude-opus-4.6", - "beeper/fast": "openai/gpt-5-mini", - "beeper/reasoning": "openai/gpt-5.2", - "beeper/smart": "openai/gpt-5.2", + "beeper/fast": "openai/gpt-5.4-mini", + "beeper/reasoning": "openai/o3", + "beeper/smart": "openai/gpt-5.4", }, } diff --git a/bridges/ai/beeper_models_manifest_test.go b/bridges/ai/beeper_models_manifest_test.go index 47db642c..4d1409a1 100644 --- a/bridges/ai/beeper_models_manifest_test.go +++ b/bridges/ai/beeper_models_manifest_test.go @@ -4,75 +4,34 @@ import "testing" func TestModelManifestMatchesOpenRouterAllowlist(t *testing.T) { expected := map[string]struct{}{ - "google/gemini-3.1-flash-lite-preview": {}, - "openai/gpt-5.3-chat": {}, - "openai/gpt-5.4": {}, - "qwen/qwen2.5-vl-32b-instruct": {}, - "qwen/qwen3-32b": {}, - "qwen/qwen3-235b-a22b": {}, - "qwen/qwen3-coder": {}, - "anthropic/claude-sonnet-4": {}, - "anthropic/claude-sonnet-4.5": {}, - "anthropic/claude-sonnet-4.6": {}, - "anthropic/claude-opus-4.1": {}, - "anthropic/claude-haiku-4.5": {}, - "anthropic/claude-opus-4.5": {}, - "anthropic/claude-opus-4.6": {}, - "deepseek/deepseek-chat-v3-0324": {}, - "deepseek/deepseek-chat-v3.1": {}, - "deepseek/deepseek-v3.1-terminus": {}, - "deepseek/deepseek-v3.2": {}, - "deepseek/deepseek-r1": {}, - "deepseek/deepseek-r1-0528": {}, - "deepseek/deepseek-r1-distill-qwen-32b": {}, - "google/gemini-2.0-flash-001": {}, - "google/gemini-2.5-flash": {}, - "google/gemini-2.5-flash-lite": {}, - "google/gemini-2.5-flash-image": {}, - "google/gemini-2.0-flash-lite-001": {}, - "google/gemini-2.5-pro": {}, - "google/gemini-3.1-pro-preview": {}, - "google/gemini-3-pro-image-preview": {}, - "google/gemini-3-flash-preview": {}, - "meta-llama/llama-3.3-70b-instruct": {}, - "meta-llama/llama-4-scout": {}, - "meta-llama/llama-4-maverick": {}, - "minimax/minimax-m2": {}, - "minimax/minimax-m2.1": {}, - "minimax/minimax-m2.5": {}, - "moonshotai/kimi-k2": {}, - "moonshotai/kimi-k2-0905": {}, - "moonshotai/kimi-k2.5": {}, - "openai/gpt-oss-20b": {}, - "openai/gpt-oss-120b": {}, - "openai/gpt-4o-mini": {}, - "openai/gpt-4.1": {}, - "openai/gpt-4.1-mini": {}, - "openai/gpt-4.1-nano": {}, - "openai/gpt-5": {}, - "openai/gpt-5-mini": {}, - "openai/gpt-5-nano": {}, - "openai/gpt-5.1": {}, - "openai/gpt-5.2": {}, - "openai/gpt-5.2-pro": {}, - "openai/o3-mini": {}, - "openai/o4-mini": {}, - "openai/o3": {}, - "openai/o3-pro": {}, - "openai/gpt-5-image-mini": {}, - "openai/gpt-5-image": {}, - "z-ai/glm-4.5": {}, - "z-ai/glm-4.5v": {}, - "z-ai/glm-4.5-air": {}, - "z-ai/glm-4.6": {}, - "z-ai/glm-4.6v": {}, - "z-ai/glm-4.7": {}, - "z-ai/glm-5": {}, - "x-ai/grok-4": {}, - "x-ai/grok-3": {}, - "x-ai/grok-3-mini": {}, - "x-ai/grok-4-fast": {}, - "x-ai/grok-4.1-fast": {}, + "anthropic/claude-haiku-4.5": {}, + "anthropic/claude-opus-4.6": {}, + "anthropic/claude-sonnet-4.6": {}, + "deepseek/deepseek-r1-0528": {}, + "deepseek/deepseek-v3.2": {}, + "google/gemini-2.5-flash-lite": {}, + "google/gemini-2.5-pro": {}, + "google/gemini-3-flash-preview": {}, + "meta-llama/llama-4-maverick": {}, + "minimax/minimax-m2.7": {}, + "mistralai/devstral-2512": {}, + "mistralai/mistral-small-2603": {}, + "moonshotai/kimi-k2.5": {}, + "openai/gpt-5-mini": {}, + "openai/gpt-5.2": {}, + "openai/gpt-5.3-codex": {}, + "openai/gpt-5.4": {}, + "openai/gpt-5.4-mini": {}, + "openai/o3": {}, + "openai/o4-mini": {}, + "qwen/qwen2.5-vl-32b-instruct": {}, + "qwen/qwen3-coder-next": {}, + "qwen/qwen3.5-flash-02-23": {}, + "qwen/qwen3.5-plus-02-15": {}, + "x-ai/grok-4.1-fast": {}, + "x-ai/grok-4.20-beta": {}, + "x-ai/grok-code-fast-1": {}, + "z-ai/glm-5-turbo": {}, } if len(ModelManifest.Models) != len(expected) { diff --git a/bridges/ai/chat_login_redirect_test.go b/bridges/ai/chat_login_redirect_test.go index 42ea3941..9e473323 100644 --- a/bridges/ai/chat_login_redirect_test.go +++ b/bridges/ai/chat_login_redirect_test.go @@ -140,16 +140,16 @@ func TestModelRedirectTarget(t *testing.T) { } func TestResolveModelIDFromManifestAcceptsRawModelID(t *testing.T) { - const modelID = "google/gemini-2.0-flash-lite-001" + const modelID = "google/gemini-3-flash-preview" if got := resolveModelIDFromManifest(modelID); got != modelID { t.Fatalf("expected raw model ID %q to resolve, got %q", modelID, got) } } func TestResolveModelIDFromManifestAcceptsEncodedModelIDViaCandidates(t *testing.T) { - const encoded = "google%2Fgemini-2.0-flash-lite-001" + const encoded = "google%2Fgemini-3-flash-preview" candidates := candidateModelLookupIDs(encoded) - const canonical = "google/gemini-2.0-flash-lite-001" + const canonical = "google/gemini-3-flash-preview" if !slices.Contains(candidates, canonical) { t.Fatalf("expected decoded model candidate in %#v", candidates) } @@ -169,8 +169,8 @@ func TestCandidateModelLookupIDsRejectsMalformedEncoding(t *testing.T) { } func TestParseModelFromGhostIDAcceptsEscapedGhostID(t *testing.T) { - const ghostID = "model-google%2Fgemini-2.0-flash-lite-001" - const want = "google/gemini-2.0-flash-lite-001" + const ghostID = "model-google%2Fgemini-3-flash-preview" + const want = "google/gemini-3-flash-preview" if got := parseModelFromGhostID(ghostID); got != want { t.Fatalf("expected ghost ID %q to parse to %q, got %q", ghostID, want, got) } @@ -190,8 +190,8 @@ func TestResolveIdentifierAcceptsCanonicalModelIdentifier(t *testing.T) { Metadata: &UserLoginMetadata{ ModelCache: &ModelCache{ Models: []ModelInfo{{ - ID: "openai/gpt-5", - Name: "GPT-5", + ID: "openai/gpt-5.4", + Name: "GPT-5.4", }}, }, }, @@ -200,11 +200,11 @@ func TestResolveIdentifierAcceptsCanonicalModelIdentifier(t *testing.T) { connector: &OpenAIConnector{}, } - resp, err := oc.ResolveIdentifier(context.Background(), "model:openai/gpt-5", false) + resp, err := oc.ResolveIdentifier(context.Background(), "model:openai/gpt-5.4", false) if err != nil { t.Fatalf("ResolveIdentifier returned error: %v", err) } - if resp == nil || resp.UserID != modelUserID("openai/gpt-5") { + if resp == nil || resp.UserID != modelUserID("openai/gpt-5.4") { t.Fatalf("expected canonical model identifier to resolve to model ghost, got %#v", resp) } } diff --git a/bridges/ai/client.go b/bridges/ai/client.go index 842dcac5..1466d972 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -1017,6 +1017,7 @@ func updateGhostLastSync(_ context.Context, ghost *bridgev2.Ghost) bool { func (oc *AIClient) GetCapabilities(ctx context.Context, portal *bridgev2.Portal) *event.RoomFeatures { meta := portalMeta(portal) + isModelRoom := meta != nil && meta.ResolvedTarget != nil && meta.ResolvedTarget.Kind == ResolvedTargetModel // Always recompute effective room capabilities from the resolved room target. modelCaps := oc.getRoomCapabilities(ctx, meta) @@ -1036,11 +1037,17 @@ func (oc *AIClient) GetCapabilities(ctx context.Context, portal *bridgev2.Portal if supportsMsgActions { caps.Reply = event.CapLevelFullySupported - caps.Edit = event.CapLevelFullySupported - caps.EditMaxCount = 10 - caps.EditMaxAge = ptr.Ptr(jsontime.S(AIEditMaxAge)) caps.Reaction = event.CapLevelFullySupported caps.ReactionCount = 1 + if isModelRoom { + caps.Edit = event.CapLevelRejected + caps.EditMaxCount = 0 + caps.EditMaxAge = nil + } else { + caps.Edit = event.CapLevelFullySupported + caps.EditMaxCount = 10 + caps.EditMaxAge = ptr.Ptr(jsontime.S(AIEditMaxAge)) + } } else { // Use explicit rejected levels so features remain visible in // com.beeper.room_features instead of being omitted by omitempty. @@ -1593,8 +1600,8 @@ func resolveModelIDFromManifest(modelID string) string { return "" } -// listAvailableModels fetches models from OpenAI API and caches them -// Returns ModelInfo list from the provider +// listAvailableModels loads models from the derived catalog and caches them. +// The implicit catalog is fed from the OpenRouter-backed manifest. func (oc *AIClient) listAvailableModels(ctx context.Context, forceRefresh bool) ([]ModelInfo, error) { meta := loginMetadata(oc.UserLogin) diff --git a/bridges/ai/client_capabilities_test.go b/bridges/ai/client_capabilities_test.go index f974f9b0..8bd55411 100644 --- a/bridges/ai/client_capabilities_test.go +++ b/bridges/ai/client_capabilities_test.go @@ -29,8 +29,8 @@ func TestGetCapabilities_ModelRoomAllowsReplyEditReaction(t *testing.T) { if caps.Reply != event.CapLevelFullySupported { t.Fatalf("expected reply fully supported in model room, got %v", caps.Reply) } - if caps.Edit != event.CapLevelFullySupported { - t.Fatalf("expected edit fully supported in model room, got %v", caps.Edit) + if caps.Edit != event.CapLevelRejected { + t.Fatalf("expected edit rejected in model room, got %v", caps.Edit) } if caps.Reaction != event.CapLevelFullySupported { t.Fatalf("expected reaction fully supported in model room, got %v", caps.Reaction) @@ -40,8 +40,12 @@ func TestGetCapabilities_ModelRoomAllowsReplyEditReaction(t *testing.T) { if err != nil { t.Fatalf("failed to marshal room features: %v", err) } - if !strings.Contains(string(raw), `"reaction":2`) { - t.Fatalf("expected serialized room features to contain reaction=2, got: %s", string(raw)) + rawJSON := string(raw) + if !strings.Contains(rawJSON, `"reaction":2`) { + t.Fatalf("expected serialized room features to contain reaction=2, got: %s", rawJSON) + } + if !strings.Contains(rawJSON, `"edit":-2`) { + t.Fatalf("expected serialized room features to contain edit=-2, got: %s", rawJSON) } } diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 1a9d6429..4604d084 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -323,15 +323,17 @@ func (oc *AIClient) HandleMatrixTyping(ctx context.Context, typing *bridgev2.Mat // HandleMatrixEdit handles edits to previously sent messages func (oc *AIClient) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixEdit) error { - if edit.Content == nil || edit.EditTarget == nil { - return errors.New("invalid edit: missing content or target") - } - portal := edit.Portal if portal == nil { return errors.New("portal is nil") } meta := portalMeta(portal) + if meta != nil && meta.ResolvedTarget != nil && meta.ResolvedTarget.Kind == ResolvedTargetModel { + return bridgev2.ErrEditsNotSupportedInPortal + } + if edit.Content == nil || edit.EditTarget == nil { + return errors.New("invalid edit: missing content or target") + } // Get the new message body newBody := strings.TrimSpace(edit.Content.Body) diff --git a/bridges/ai/handlematrix_edit_test.go b/bridges/ai/handlematrix_edit_test.go new file mode 100644 index 00000000..d71dc0b2 --- /dev/null +++ b/bridges/ai/handlematrix_edit_test.go @@ -0,0 +1,61 @@ +package ai + +import ( + "context" + "testing" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" +) + +func TestHandleMatrixEdit_ModelRoomRejectsEdits(t *testing.T) { + oc := &AIClient{} + edit := &bridgev2.MatrixEdit{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.MessageEventContent]{ + Portal: &bridgev2.Portal{ + Portal: &database.Portal{ + OtherUserID: modelUserID("openai/gpt-5"), + Metadata: modelModeTestMeta("openai/gpt-5"), + }, + }, + Content: &event.MessageEventContent{Body: "updated"}, + }, + EditTarget: &database.Message{}, + } + + err := oc.HandleMatrixEdit(context.Background(), edit) + if err == nil { + t.Fatal("expected model room edit to be rejected") + } + if err.Error() != bridgev2.ErrEditsNotSupportedInPortal.Error() { + t.Fatalf("expected ErrEditsNotSupportedInPortal, got %v", err) + } +} + +func TestHandleMatrixEdit_AgentRoomStillUsesAgentPath(t *testing.T) { + oc := &AIClient{} + edit := &bridgev2.MatrixEdit{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.MessageEventContent]{ + Portal: &bridgev2.Portal{ + Portal: &database.Portal{ + OtherUserID: agentUserID("beeper"), + Metadata: agentModeTestMeta("beeper"), + }, + }, + Content: &event.MessageEventContent{Body: " "}, + }, + EditTarget: &database.Message{}, + } + + err := oc.HandleMatrixEdit(context.Background(), edit) + if err == nil { + t.Fatal("expected agent edit to continue into the existing handler path") + } + if err.Error() == bridgev2.ErrEditsNotSupportedInPortal.Error() { + t.Fatalf("expected agent room edit to avoid model-room rejection, got %v", err) + } + if err.Error() != "empty edit body" { + t.Fatalf("expected empty edit body error from existing path, got %v", err) + } +} diff --git a/bridges/ai/image_generation_tool.go b/bridges/ai/image_generation_tool.go index 70314fe3..1df8a926 100644 --- a/bridges/ai/image_generation_tool.go +++ b/bridges/ai/image_generation_tool.go @@ -176,12 +176,35 @@ func resolveImageGenProvider(req imageGenRequest, btc *BridgeToolContext) (image } } + loginMeta := loginMetadata(btc.Client.UserLogin) + inferredProvider := inferProviderFromModel(req.Model) + if inferredProvider != "" { + switch inferredProvider { + case imageGenProviderOpenAI: + if supportsOpenAIImageGen(btc) { + return imageGenProviderOpenAI, nil + } + case imageGenProviderGemini: + if supportsGeminiImageGen(btc) { + return imageGenProviderGemini, nil + } + case imageGenProviderOpenRouter: + if supportsOpenRouterImageGen(btc) { + return imageGenProviderOpenRouter, nil + } + } + // Magic Proxy only exposes the OpenAI images route in practice, so use + // that when a requested image model belongs to an unavailable surface. + if loginMeta != nil && loginMeta.Provider == ProviderMagicProxy && supportsOpenAIImageGen(btc) { + return imageGenProviderOpenAI, nil + } + } + // Prefer OpenRouter image gen whenever it's available (Gemini models support extra controls). if supportsOpenRouterImageGen(btc) { return imageGenProviderOpenRouter, nil } - loginMeta := loginMetadata(btc.Client.UserLogin) switch loginMeta.Provider { case ProviderOpenAI: if !supportsOpenAIImageGen(btc) { @@ -267,12 +290,20 @@ func normalizeOpenAIModel(model string) string { if model == "" { return defaultOpenAIImageModel } + switch inferProviderFromModel(model) { + case imageGenProviderGemini, imageGenProviderOpenRouter: + return defaultOpenAIImageModel + } _, actual := ParseModelPrefix(model) actual = strings.TrimSpace(actual) actual = strings.TrimPrefix(actual, "openai/") if actual == "" { return model } + switch strings.ToLower(actual) { + case "gpt-5-image", "gpt-5-image-mini": + return defaultOpenAIImageModel + } return actual } diff --git a/bridges/ai/image_generation_tool_magic_proxy_test.go b/bridges/ai/image_generation_tool_magic_proxy_test.go index 2c1d3864..ef8ffdcf 100644 --- a/bridges/ai/image_generation_tool_magic_proxy_test.go +++ b/bridges/ai/image_generation_tool_magic_proxy_test.go @@ -69,6 +69,29 @@ func TestResolveImageGenProviderMagicProxyProviderOpenAIUsesOpenAI(t *testing.T) } } +func TestResolveImageGenProviderMagicProxyModelHintFallsBackToOpenAI(t *testing.T) { + meta := &UserLoginMetadata{ + Provider: ProviderMagicProxy, + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, + } + btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) + + got, err := resolveImageGenProvider(imageGenRequest{ + Model: "google/gemini-3-pro-image-preview", + Prompt: "cat", + Count: 1, + }, btc) + if err != nil { + t.Fatalf("resolveImageGenProvider returned error: %v", err) + } + if got != imageGenProviderOpenAI { + t.Fatalf("expected provider %q, got %q", imageGenProviderOpenAI, got) + } +} + func TestResolveImageGenProviderMagicProxyProviderGeminiIsUnavailable(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, @@ -89,6 +112,27 @@ func TestResolveImageGenProviderMagicProxyProviderGeminiIsUnavailable(t *testing } } +func TestNormalizeOpenAIModelMapsUnavailableAliasesToGPTImage1(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {name: "empty", input: "", want: "gpt-image-1"}, + {name: "prefixed alias", input: "openai/gpt-5-image", want: "gpt-image-1"}, + {name: "mini alias", input: "gpt-5-image-mini", want: "gpt-image-1"}, + {name: "gemini alias", input: "google/gemini-3-pro-image-preview", want: "gpt-image-1"}, + {name: "native openai", input: "gpt-image-1", want: "gpt-image-1"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := normalizeOpenAIModel(tc.input); got != tc.want { + t.Fatalf("normalizeOpenAIModel(%q) = %q, want %q", tc.input, got, tc.want) + } + }) + } +} + func TestBuildOpenAIImagesBaseURLMagicProxy(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, diff --git a/bridges/ai/model_catalog_test.go b/bridges/ai/model_catalog_test.go index 02e21fd0..fc3d56df 100644 --- a/bridges/ai/model_catalog_test.go +++ b/bridges/ai/model_catalog_test.go @@ -20,3 +20,40 @@ func TestImplicitModelCatalogEntries_MagicProxySeedsCatalog(t *testing.T) { t.Fatalf("expected non-empty model catalog entries for magic_proxy, got 0") } } + +func TestImplicitModelCatalogEntries_OpenAILoginUsesManifestMetadata(t *testing.T) { + oc := &AIClient{ + connector: &OpenAIConnector{}, + } + meta := &UserLoginMetadata{ + Provider: ProviderOpenAI, + Credentials: &LoginCredentials{ + APIKey: "openai-token", + }, + } + + entries := oc.implicitModelCatalogEntries(meta) + if len(entries) == 0 { + t.Fatalf("expected non-empty model catalog entries for openai, got 0") + } + + entry := findModelCatalogEntry(entries, ProviderOpenAI, "gpt-5.4-mini") + if entry == nil { + t.Fatal("expected gpt-5.4-mini entry in openai catalog") + } + + manifestInfo, ok := ModelManifest.Models["openai/gpt-5.4-mini"] + if !ok { + t.Fatal("expected gpt-5.4-mini in manifest") + } + + if entry.ContextWindow != manifestInfo.ContextWindow { + t.Fatalf("context window = %d, want %d", entry.ContextWindow, manifestInfo.ContextWindow) + } + if entry.MaxOutputTokens != manifestInfo.MaxOutputTokens { + t.Fatalf("max output tokens = %d, want %d", entry.MaxOutputTokens, manifestInfo.MaxOutputTokens) + } + if !catalogInputIncludes(entry, "image") || !catalogInputIncludes(entry, "pdf") { + t.Fatalf("expected openai catalog entry to retain manifest modalities, got %#v", entry.Input) + } +} diff --git a/cmd/generate-models/main.go b/cmd/generate-models/main.go index 3626de85..76f78d58 100644 --- a/cmd/generate-models/main.go +++ b/cmd/generate-models/main.go @@ -26,104 +26,43 @@ var modelConfig = struct { Aliases map[string]string }{ Models: map[string]string{ - // Anthropic (Claude) via OpenRouter - "anthropic/claude-haiku-4.5": "Claude Haiku 4.5", - "anthropic/claude-opus-4.1": "Claude 4.1 Opus", - "anthropic/claude-opus-4.5": "Claude Opus 4.5", - "anthropic/claude-opus-4.6": "Claude Opus 4.6", - "anthropic/claude-sonnet-4": "Claude 4 Sonnet", - "anthropic/claude-sonnet-4.5": "Claude Sonnet 4.5", - "anthropic/claude-sonnet-4.6": "Claude Sonnet 4.6", - - // DeepSeek - "deepseek/deepseek-chat-v3-0324": "DeepSeek v3 (0324)", - "deepseek/deepseek-chat-v3.1": "DeepSeek v3.1", - "deepseek/deepseek-v3.1-terminus": "DeepSeek v3.1 Terminus", - "deepseek/deepseek-v3.2": "DeepSeek v3.2", - "deepseek/deepseek-r1": "DeepSeek R1 (Original)", - "deepseek/deepseek-r1-0528": "DeepSeek R1 (0528)", - "deepseek/deepseek-r1-distill-qwen-32b": "DeepSeek R1 (Qwen Distilled)", - - // Gemini (Google) via OpenRouter - "google/gemini-2.0-flash-001": "Gemini 2.0 Flash", - "google/gemini-2.0-flash-lite-001": "Gemini 2.0 Flash Lite", - "google/gemini-2.5-flash": "Gemini 2.5 Flash", - "google/gemini-2.5-flash-image": "Nano Banana", - "google/gemini-2.5-flash-lite": "Gemini 2.5 Flash Lite", - "google/gemini-2.5-pro": "Gemini 2.5 Pro", - "google/gemini-3-flash-preview": "Gemini 3 Flash", - "google/gemini-3-pro-image-preview": "Nano Banana Pro", - "google/gemini-3.1-flash-lite-preview": "Gemini 3.1 Flash Lite", - "google/gemini-3.1-pro-preview": "Gemini 3.1 Pro", - // For Gemini image generation, use Nano Banana / Nano Banana Pro. - - // GLM (Z.AI) - "z-ai/glm-4.5": "GLM 4.5", - "z-ai/glm-4.5-air": "GLM 4.5 Air", - "z-ai/glm-4.5v": "GLM 4.5V", - "z-ai/glm-4.6": "GLM 4.6", - "z-ai/glm-4.6v": "GLM 4.6V", - "z-ai/glm-4.7": "GLM 4.7", - "z-ai/glm-5": "GLM 5", - - // Kimi (Moonshot) - "moonshotai/kimi-k2": "Kimi K2 (0711)", - "moonshotai/kimi-k2-0905": "Kimi K2 (0905)", - "moonshotai/kimi-k2.5": "Kimi K2.5", - - // Llama (Meta) - "meta-llama/llama-3.3-70b-instruct": "Llama 3.3 70B", - "meta-llama/llama-4-maverick": "Llama 4 Maverick", - "meta-llama/llama-4-scout": "Llama 4 Scout", - - // MiniMax - "minimax/minimax-m2": "MiniMax M2", - "minimax/minimax-m2.1": "MiniMax M2.1", - "minimax/minimax-m2.5": "MiniMax M2.5", - - // OpenAI models via OpenRouter - "openai/gpt-4.1": "GPT-4.1", - "openai/gpt-4.1-mini": "GPT-4.1 Mini", - "openai/gpt-4.1-nano": "GPT-4.1 Nano", - "openai/gpt-4o-mini": "GPT-4o-mini", - "openai/gpt-5": "GPT-5", - "openai/gpt-5-image": "GPT ImageGen 1.5", - "openai/gpt-5-image-mini": "GPT ImageGen", - "openai/gpt-5-mini": "GPT-5 mini", - "openai/gpt-5-nano": "GPT-5 nano", - "openai/gpt-5.1": "GPT-5.1", - "openai/gpt-5.2": "GPT-5.2", - "openai/gpt-5.2-pro": "GPT-5.2 Pro", - "openai/gpt-5.3-chat": "GPT-5.3 Instant", - "openai/gpt-5.4": "GPT-5.4", - "openai/gpt-oss-20b": "GPT OSS 20B", - "openai/gpt-oss-120b": "GPT OSS 120B", - "openai/o3": "o3", - "openai/o3-mini": "o3-mini", - "openai/o3-pro": "o3 Pro", - "openai/o4-mini": "o4-mini", - - // Qwen (Alibaba) - "qwen/qwen2.5-vl-32b-instruct": "Qwen 2.5 32B", - "qwen/qwen3-32b": "Qwen 3 32B", - "qwen/qwen3-235b-a22b": "Qwen 3 235B", - "qwen/qwen3-coder": "Qwen 3 Coder", - - // xAI (Grok) - "x-ai/grok-3": "Grok 3", - "x-ai/grok-3-mini": "Grok 3 Mini", - "x-ai/grok-4": "Grok 4", - "x-ai/grok-4-fast": "Grok 4 Fast", - "x-ai/grok-4.1-fast": "Grok 4.1 Fast", + "anthropic/claude-haiku-4.5": "Claude Haiku 4.5", + "anthropic/claude-opus-4.6": "Claude Opus 4.6", + "anthropic/claude-sonnet-4.6": "Claude Sonnet 4.6", + "deepseek/deepseek-r1-0528": "DeepSeek R1 (0528)", + "deepseek/deepseek-v3.2": "DeepSeek v3.2", + "google/gemini-2.5-flash-lite": "Gemini 2.5 Flash Lite", + "google/gemini-2.5-pro": "Gemini 2.5 Pro", + "google/gemini-3-flash-preview": "Gemini 3 Flash", + "meta-llama/llama-4-maverick": "Llama 4 Maverick", + "minimax/minimax-m2.7": "MiniMax M2.7", + "mistralai/devstral-2512": "Devstral 2", + "mistralai/mistral-small-2603": "Mistral Small 4", + "moonshotai/kimi-k2.5": "Kimi K2.5", + "openai/gpt-5-mini": "GPT-5 mini", + "openai/gpt-5.2": "GPT-5.2", + "openai/gpt-5.3-codex": "GPT-5.3 Codex", + "openai/gpt-5.4": "GPT-5.4", + "openai/gpt-5.4-mini": "GPT-5.4 Mini", + "openai/o3": "o3", + "openai/o4-mini": "o4-mini", + "qwen/qwen2.5-vl-32b-instruct": "Qwen 2.5 32B", + "qwen/qwen3-coder-next": "Qwen 3 Coder Next", + "qwen/qwen3.5-flash-02-23": "Qwen 3.5 Flash", + "qwen/qwen3.5-plus-02-15": "Qwen 3.5 Plus", + "x-ai/grok-4.1-fast": "Grok 4.1 Fast", + "x-ai/grok-4.20-beta": "Grok 4.20 Beta", + "x-ai/grok-code-fast-1": "Grok Code Fast 1", + "z-ai/glm-5-turbo": "GLM 5 Turbo", }, Aliases: map[string]string{ // Default alias "beeper/default": "anthropic/claude-opus-4.6", // Stable aliases that can be remapped - "beeper/fast": "openai/gpt-5-mini", - "beeper/smart": "openai/gpt-5.2", - "beeper/reasoning": "openai/gpt-5.2", // Uses reasoning effort parameter + "beeper/fast": "openai/gpt-5.4-mini", + "beeper/smart": "openai/gpt-5.4", + "beeper/reasoning": "openai/o3", }, } @@ -188,15 +127,11 @@ func main() { } func run() error { - token := flag.String("openrouter-token", "", "OpenRouter API token") + token := flag.String("openrouter-token", "", "Optional OpenRouter API token") outputFile := flag.String("output", "bridges/ai/beeper_models_generated.go", "Output Go file") jsonFile := flag.String("json", "pkg/ai/beeper_models.json", "Output JSON file for clients") flag.Parse() - if *token == "" { - return fmt.Errorf("--openrouter-token is required") - } - models, err := fetchOpenRouterModels(*token) if err != nil { return fmt.Errorf("fetching models: %w", err) @@ -219,7 +154,9 @@ func fetchOpenRouterModels(token string) (map[string]OpenRouterModel, error) { if err != nil { return nil, err } - req.Header.Set("Authorization", "Bearer "+token) + if strings.TrimSpace(token) != "" { + req.Header.Set("Authorization", "Bearer "+token) + } client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) diff --git a/generate-models.sh b/generate-models.sh index 9503cc5e..c1400be4 100755 --- a/generate-models.sh +++ b/generate-models.sh @@ -1,6 +1,6 @@ #!/bin/bash # Generate AI models Go file from OpenRouter API -# Usage: ./generate-models.sh --openrouter-token="YOUR_TOKEN" +# Usage: ./generate-models.sh [--openrouter-token="YOUR_TOKEN"] # # This script fetches model capabilities from OpenRouter and generates # a Go file with model definitions. The generated file is checked into @@ -24,10 +24,10 @@ while [[ $# -gt 0 ]]; do shift ;; -h|--help) - echo "Usage: $0 --openrouter-token=TOKEN [--output=FILE]" + echo "Usage: $0 [--openrouter-token=TOKEN] [--output=FILE]" echo "" echo "Options:" - echo " --openrouter-token=TOKEN OpenRouter API token (required)" + echo " --openrouter-token=TOKEN Optional OpenRouter API token" echo " --output=FILE Output file path (default: bridges/ai/beeper_models_generated.go)" echo " --json=FILE Output JSON path (default: pkg/ai/beeper_models.json)" exit 0 @@ -43,12 +43,6 @@ while [[ $# -gt 0 ]]; do esac done -if [ -z "$OPENROUTER_TOKEN" ]; then - echo "Error: --openrouter-token is required" - echo "Usage: $0 --openrouter-token=TOKEN" - exit 1 -fi - # Change to script directory cd "$(dirname "$0")" diff --git a/pkg/agents/presets.go b/pkg/agents/presets.go index 804074f9..354f37a7 100644 --- a/pkg/agents/presets.go +++ b/pkg/agents/presets.go @@ -4,10 +4,10 @@ import "slices" // Model constants for preset agents (aligned with clawdbot recommended models). const ( - ModelClaudeSonnet = "anthropic/claude-sonnet-4.5" + ModelClaudeSonnet = "anthropic/claude-sonnet-4.6" ModelClaudeOpus = "anthropic/claude-opus-4.6" ModelOpenAIGPT52 = "openai/gpt-5.2" - ModelZAIGLM47 = "z-ai/glm-4.7" + ModelZAIGLM47 = "z-ai/glm-5-turbo" ) // PresetAgents contains the default agent definitions: diff --git a/pkg/ai/beeper_models.json b/pkg/ai/beeper_models.json index 57904ff2..2ac68569 100644 --- a/pkg/ai/beeper_models.json +++ b/pkg/ai/beeper_models.json @@ -18,627 +18,16 @@ ] }, { - "id": "anthropic/claude-opus-4.1", - "name": "Claude 4.1 Opus", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 32000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "anthropic/claude-opus-4.5", - "name": "Claude Opus 4.5", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 64000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "anthropic/claude-opus-4.6", - "name": "Claude Opus 4.6", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1000000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "anthropic/claude-sonnet-4", - "name": "Claude 4 Sonnet", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 64000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "anthropic/claude-sonnet-4.5", - "name": "Claude Sonnet 4.5", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1000000, - "max_output_tokens": 64000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "anthropic/claude-sonnet-4.6", - "name": "Claude Sonnet 4.6", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1000000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "deepseek/deepseek-chat-v3-0324", - "name": "DeepSeek v3 (0324)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 163840, - "max_output_tokens": 163840, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "deepseek/deepseek-chat-v3.1", - "name": "DeepSeek v3.1", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 32768, - "max_output_tokens": 7168, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "deepseek/deepseek-r1", - "name": "DeepSeek R1 (Original)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 64000, - "max_output_tokens": 16000, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "deepseek/deepseek-r1-0528", - "name": "DeepSeek R1 (0528)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 163840, - "max_output_tokens": 65536, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "deepseek/deepseek-r1-distill-qwen-32b", - "name": "DeepSeek R1 (Qwen Distilled)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": false, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 32768, - "max_output_tokens": 32768 - }, - { - "id": "deepseek/deepseek-v3.1-terminus", - "name": "DeepSeek v3.1 Terminus", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 163840, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "deepseek/deepseek-v3.2", - "name": "DeepSeek v3.2", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 163840, - "max_output_tokens": 65536, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-2.0-flash-001", - "name": "Gemini 2.0 Flash", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 8192, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-2.0-flash-lite-001", - "name": "Gemini 2.0 Flash Lite", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 8192, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-2.5-flash", - "name": "Gemini 2.5 Flash", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65535, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-2.5-flash-image", - "name": "Nano Banana", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": false, - "supports_reasoning": false, - "supports_web_search": false, - "supports_image_gen": true, - "supports_pdf": true, - "context_window": 32768, - "max_output_tokens": 32768 - }, - { - "id": "google/gemini-2.5-flash-lite", - "name": "Gemini 2.5 Flash Lite", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65535, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-2.5-pro", - "name": "Gemini 2.5 Pro", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65536, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-3-flash-preview", - "name": "Gemini 3 Flash", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65536, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-3-pro-image-preview", - "name": "Nano Banana Pro", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": false, - "supports_reasoning": true, - "supports_web_search": false, - "supports_image_gen": true, - "supports_pdf": true, - "context_window": 65536, - "max_output_tokens": 32768 - }, - { - "id": "google/gemini-3.1-flash-lite-preview", - "name": "Gemini 3.1 Flash Lite", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65536, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "google/gemini-3.1-pro-preview", - "name": "Gemini 3.1 Pro", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65536, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "meta-llama/llama-3.3-70b-instruct", - "name": "Llama 3.3 70B", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 16384, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "meta-llama/llama-4-maverick", - "name": "Llama 4 Maverick", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 1048576, - "max_output_tokens": 16384, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "meta-llama/llama-4-scout", - "name": "Llama 4 Scout", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 327680, - "max_output_tokens": 16384, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "minimax/minimax-m2", - "name": "MiniMax M2", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 196608, - "max_output_tokens": 196608, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "minimax/minimax-m2.1", - "name": "MiniMax M2.1", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 196608, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "minimax/minimax-m2.5", - "name": "MiniMax M2.5", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 196608, - "max_output_tokens": 196608, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "moonshotai/kimi-k2", - "name": "Kimi K2 (0711)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 131000, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "moonshotai/kimi-k2-0905", - "name": "Kimi K2 (0905)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 131072, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "moonshotai/kimi-k2.5", - "name": "Kimi K2.5", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 262144, - "max_output_tokens": 65535, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "openai/gpt-4.1", - "name": "GPT-4.1", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1047576, - "max_output_tokens": 32768, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-4.1-mini", - "name": "GPT-4.1 Mini", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1047576, - "max_output_tokens": 32768, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-4.1-nano", - "name": "GPT-4.1 Nano", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1047576, - "max_output_tokens": 32768, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-4o-mini", - "name": "GPT-4o-mini", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 128000, - "max_output_tokens": 16384, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5", - "name": "GPT-5", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5-image", - "name": "GPT ImageGen 1.5", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_image_gen": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5-image-mini", - "name": "GPT ImageGen", + "id": "anthropic/claude-opus-4.6", + "name": "Claude Opus 4.6", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": true, - "supports_image_gen": true, "supports_pdf": true, - "context_window": 400000, + "context_window": 1000000, "max_output_tokens": 128000, "available_tools": [ "web_search", @@ -646,16 +35,16 @@ ] }, { - "id": "openai/gpt-5-mini", - "name": "GPT-5 mini", + "id": "anthropic/claude-sonnet-4.6", + "name": "Claude Sonnet 4.6", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": true, "supports_pdf": true, - "context_window": 400000, + "context_window": 1000000, "max_output_tokens": 128000, "available_tools": [ "web_search", @@ -663,427 +52,388 @@ ] }, { - "id": "openai/gpt-5-nano", - "name": "GPT-5 nano", + "id": "deepseek/deepseek-r1-0528", + "name": "DeepSeek R1 (0528)", "provider": "openrouter", - "api": "responses", - "supports_vision": true, + "api": "openai-completions", + "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, + "supports_web_search": false, + "context_window": 163840, + "max_output_tokens": 65536, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "openai/gpt-5.1", - "name": "GPT-5.1", + "id": "deepseek/deepseek-v3.2", + "name": "DeepSeek v3.2", "provider": "openrouter", - "api": "responses", - "supports_vision": true, + "api": "openai-completions", + "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, + "supports_web_search": false, + "context_window": 163840, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "openai/gpt-5.2", - "name": "GPT-5.2", + "id": "google/gemini-2.5-flash-lite", + "name": "Gemini 2.5 Flash Lite", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, + "supports_web_search": false, + "supports_audio": true, + "supports_video": true, "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, + "context_window": 1048576, + "max_output_tokens": 65535, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "openai/gpt-5.2-pro", - "name": "GPT-5.2 Pro", + "id": "google/gemini-2.5-pro", + "name": "Gemini 2.5 Pro", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, + "supports_web_search": false, + "supports_audio": true, + "supports_video": true, "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, + "context_window": 1048576, + "max_output_tokens": 65536, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "openai/gpt-5.3-chat", - "name": "GPT-5.3 Instant", + "id": "google/gemini-3-flash-preview", + "name": "Gemini 3 Flash", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, + "supports_reasoning": true, + "supports_web_search": false, + "supports_audio": true, + "supports_video": true, "supports_pdf": true, - "context_window": 128000, - "max_output_tokens": 16384, + "context_window": 1048576, + "max_output_tokens": 65536, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "openai/gpt-5.4", - "name": "GPT-5.4", + "id": "meta-llama/llama-4-maverick", + "name": "Llama 4 Maverick", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1050000, - "max_output_tokens": 128000, + "supports_reasoning": false, + "supports_web_search": false, + "context_window": 1048576, + "max_output_tokens": 16384, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "openai/gpt-oss-120b", - "name": "GPT OSS 120B", + "id": "minimax/minimax-m2.7", + "name": "MiniMax M2.7", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "context_window": 131072, + "context_window": 204800, + "max_output_tokens": 131072, "available_tools": [ "function_calling" ] }, { - "id": "openai/gpt-oss-20b", - "name": "GPT OSS 20B", + "id": "mistralai/devstral-2512", + "name": "Devstral 2", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": false, "supports_tool_calling": true, - "supports_reasoning": true, + "supports_reasoning": false, "supports_web_search": false, - "context_window": 131072, + "context_window": 262144, "available_tools": [ "function_calling" ] }, { - "id": "openai/o3", - "name": "o3", + "id": "mistralai/mistral-small-2603", + "name": "Mistral Small 4", "provider": "openrouter", - "api": "responses", + "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, + "supports_web_search": false, + "context_window": 262144, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "openai/o3-mini", - "name": "o3-mini", + "id": "moonshotai/kimi-k2.5", + "name": "Kimi K2.5", "provider": "openrouter", - "api": "responses", - "supports_vision": false, + "api": "openai-completions", + "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": false, + "supports_reasoning": true, "supports_web_search": false, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, + "context_window": 262144, + "max_output_tokens": 65535, "available_tools": [ "function_calling" ] }, { - "id": "openai/o3-pro", - "name": "o3 Pro", + "id": "openai/gpt-5-mini", + "name": "GPT-5 mini", "provider": "openrouter", - "api": "responses", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": true, "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "openai/o4-mini", - "name": "o4-mini", + "id": "openai/gpt-5.2", + "name": "GPT-5.2", "provider": "openrouter", - "api": "responses", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": true, "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "qwen/qwen2.5-vl-32b-instruct", - "name": "Qwen 2.5 32B", + "id": "openai/gpt-5.3-codex", + "name": "GPT-5.3 Codex", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, - "supports_tool_calling": false, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 128000 - }, - { - "id": "qwen/qwen3-235b-a22b", - "name": "Qwen 3 235B", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 8192, + "supports_web_search": true, + "supports_pdf": true, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "qwen/qwen3-32b", - "name": "Qwen 3 32B", + "id": "openai/gpt-5.4", + "name": "GPT-5.4", "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, + "api": "openai-responses", + "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "context_window": 40960, - "max_output_tokens": 40960, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "qwen/qwen3-coder", - "name": "Qwen 3 Coder", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 262144, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "x-ai/grok-3", - "name": "Grok 3", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, "supports_web_search": true, - "context_window": 131072, + "supports_pdf": true, + "context_window": 1050000, + "max_output_tokens": 128000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "x-ai/grok-3-mini", - "name": "Grok 3 Mini", + "id": "openai/gpt-5.4-mini", + "name": "GPT-5.4 Mini", "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, + "api": "openai-responses", + "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": true, - "context_window": 131072, + "supports_pdf": true, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "x-ai/grok-4", - "name": "Grok 4", + "id": "openai/o3", + "name": "o3", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": true, - "context_window": 256000, + "supports_pdf": true, + "context_window": 200000, + "max_output_tokens": 100000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "x-ai/grok-4-fast", - "name": "Grok 4 Fast", + "id": "openai/o4-mini", + "name": "o4-mini", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": true, - "context_window": 2000000, - "max_output_tokens": 30000, + "supports_pdf": true, + "context_window": 200000, + "max_output_tokens": 100000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "x-ai/grok-4.1-fast", - "name": "Grok 4.1 Fast", + "id": "qwen/qwen2.5-vl-32b-instruct", + "name": "Qwen 2.5 32B", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "context_window": 2000000, - "max_output_tokens": 30000, - "available_tools": [ - "web_search", - "function_calling" - ] + "supports_tool_calling": false, + "supports_reasoning": false, + "supports_web_search": false, + "context_window": 128000 }, { - "id": "z-ai/glm-4.5", - "name": "GLM 4.5", + "id": "qwen/qwen3-coder-next", + "name": "Qwen 3 Coder Next", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, "supports_tool_calling": true, - "supports_reasoning": true, + "supports_reasoning": false, "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 98304, + "context_window": 262144, + "max_output_tokens": 65536, "available_tools": [ "function_calling" ] }, { - "id": "z-ai/glm-4.5-air", - "name": "GLM 4.5 Air", + "id": "qwen/qwen3.5-flash-02-23", + "name": "Qwen 3.5 Flash", "provider": "openrouter", "api": "openai-completions", - "supports_vision": false, + "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 98304, + "supports_video": true, + "context_window": 1000000, + "max_output_tokens": 65536, "available_tools": [ "function_calling" ] }, { - "id": "z-ai/glm-4.5v", - "name": "GLM 4.5V", + "id": "qwen/qwen3.5-plus-02-15", + "name": "Qwen 3.5 Plus", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "context_window": 65536, - "max_output_tokens": 16384, + "supports_video": true, + "context_window": 1000000, + "max_output_tokens": 65536, "available_tools": [ "function_calling" ] }, { - "id": "z-ai/glm-4.6", - "name": "GLM 4.6", + "id": "x-ai/grok-4.1-fast", + "name": "Grok 4.1 Fast", "provider": "openrouter", "api": "openai-completions", - "supports_vision": false, + "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "context_window": 204800, - "max_output_tokens": 204800, + "supports_web_search": true, + "context_window": 2000000, + "max_output_tokens": 30000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "z-ai/glm-4.6v", - "name": "GLM 4.6V", + "id": "x-ai/grok-4.20-beta", + "name": "Grok 4.20 Beta", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "supports_video": true, - "context_window": 131072, - "max_output_tokens": 131072, + "supports_web_search": true, + "context_window": 2000000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "z-ai/glm-4.7", - "name": "GLM 4.7", + "id": "x-ai/grok-code-fast-1", + "name": "Grok Code Fast 1", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "context_window": 202752, + "supports_web_search": true, + "context_window": 256000, + "max_output_tokens": 10000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "z-ai/glm-5", - "name": "GLM 5", + "id": "z-ai/glm-5-turbo", + "name": "GLM 5 Turbo", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, @@ -1091,6 +441,7 @@ "supports_reasoning": true, "supports_web_search": false, "context_window": 202752, + "max_output_tokens": 131072, "available_tools": [ "function_calling" ] @@ -1098,8 +449,8 @@ ], "aliases": { "beeper/default": "anthropic/claude-opus-4.6", - "beeper/fast": "openai/gpt-5-mini", - "beeper/reasoning": "openai/gpt-5.2", - "beeper/smart": "openai/gpt-5.2" + "beeper/fast": "openai/gpt-5.4-mini", + "beeper/reasoning": "openai/o3", + "beeper/smart": "openai/gpt-5.4" } } From dc94f1b74ef73df81f52692ef329da85bf4cb98c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 02:58:44 +0200 Subject: [PATCH 19/23] Add Gemma 2 27B model (google/gemma-2-27b-it) Register the new model "google/gemma-2-27b-it" (Gemma 2 27B). Added model metadata to pkg/ai/beeper_models.json, added the mapping in cmd/generate-models/main.go, and regenerated bridges/ai/beeper_models_generated.go. Model metadata: provider=openrouter, api=openai-completions, context_window=8192, max_output_tokens=2048, and no vision/tool-calling/reasoning/web-search support. --- bridges/ai/beeper_models_generated.go | 19 ++++++++++++++++++- cmd/generate-models/main.go | 1 + pkg/ai/beeper_models.json | 12 ++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/bridges/ai/beeper_models_generated.go b/bridges/ai/beeper_models_generated.go index bf5e0348..ecb15657 100644 --- a/bridges/ai/beeper_models_generated.go +++ b/bridges/ai/beeper_models_generated.go @@ -1,5 +1,5 @@ // Code generated by generate-models. DO NOT EDIT. -// Generated at: 2026-03-30T00:25:29Z +// Generated at: 2026-03-30T00:58:12Z package ai @@ -146,6 +146,23 @@ var ModelManifest = struct { MaxOutputTokens: 65536, AvailableTools: []string{ToolFunctionCalling}, }, + "google/gemma-2-27b-it": { + ID: "google/gemma-2-27b-it", + Name: "Gemma 2 27B", + Provider: "openrouter", + API: "openai-completions", + SupportsVision: false, + SupportsToolCalling: false, + SupportsReasoning: false, + SupportsWebSearch: false, + SupportsImageGen: false, + SupportsAudio: false, + SupportsVideo: false, + SupportsPDF: false, + ContextWindow: 8192, + MaxOutputTokens: 2048, + AvailableTools: []string{}, + }, "meta-llama/llama-4-maverick": { ID: "meta-llama/llama-4-maverick", Name: "Llama 4 Maverick", diff --git a/cmd/generate-models/main.go b/cmd/generate-models/main.go index 76f78d58..03cc7572 100644 --- a/cmd/generate-models/main.go +++ b/cmd/generate-models/main.go @@ -34,6 +34,7 @@ var modelConfig = struct { "google/gemini-2.5-flash-lite": "Gemini 2.5 Flash Lite", "google/gemini-2.5-pro": "Gemini 2.5 Pro", "google/gemini-3-flash-preview": "Gemini 3 Flash", + "google/gemma-2-27b-it": "Gemma 2 27B", "meta-llama/llama-4-maverick": "Llama 4 Maverick", "minimax/minimax-m2.7": "MiniMax M2.7", "mistralai/devstral-2512": "Devstral 2", diff --git a/pkg/ai/beeper_models.json b/pkg/ai/beeper_models.json index 2ac68569..4cd1b0f5 100644 --- a/pkg/ai/beeper_models.json +++ b/pkg/ai/beeper_models.json @@ -134,6 +134,18 @@ "function_calling" ] }, + { + "id": "google/gemma-2-27b-it", + "name": "Gemma 2 27B", + "provider": "openrouter", + "api": "openai-completions", + "supports_vision": false, + "supports_tool_calling": false, + "supports_reasoning": false, + "supports_web_search": false, + "context_window": 8192, + "max_output_tokens": 2048 + }, { "id": "meta-llama/llama-4-maverick", "name": "Llama 4 Maverick", From 712e4425e6006ec417bd352ff11f7bda790e72a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 02:58:47 +0200 Subject: [PATCH 20/23] Update beeper_models_manifest_test.go --- bridges/ai/beeper_models_manifest_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridges/ai/beeper_models_manifest_test.go b/bridges/ai/beeper_models_manifest_test.go index 4d1409a1..3abdeb6e 100644 --- a/bridges/ai/beeper_models_manifest_test.go +++ b/bridges/ai/beeper_models_manifest_test.go @@ -12,6 +12,7 @@ func TestModelManifestMatchesOpenRouterAllowlist(t *testing.T) { "google/gemini-2.5-flash-lite": {}, "google/gemini-2.5-pro": {}, "google/gemini-3-flash-preview": {}, + "google/gemma-2-27b-it": {}, "meta-llama/llama-4-maverick": {}, "minimax/minimax-m2.7": {}, "mistralai/devstral-2512": {}, From e60b341e575a7f08e97f0256a8454164abc7a431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 03:21:42 +0200 Subject: [PATCH 21/23] Extract follow-up test helper; remove unused wrappers Add getFollowUpMessagesForTest in agent_loop_steering_test.go and update tests to call it instead of the AIClient method. Remove the oc.getFollowUpMessages implementation from pending_queue.go and delete the unused registerRoomRunPendingItem wrapper in room_runs.go (keep the locked registerRoomRunPendingItemLocked). These changes reduce API surface on AIClient and consolidate test-only logic into the test file. --- bridges/ai/agent_loop_steering_test.go | 43 ++++++++++++++++++++------ bridges/ai/pending_queue.go | 26 ---------------- bridges/ai/room_runs.go | 10 ------ 3 files changed, 33 insertions(+), 46 deletions(-) diff --git a/bridges/ai/agent_loop_steering_test.go b/bridges/ai/agent_loop_steering_test.go index e3be7ee7..74776b61 100644 --- a/bridges/ai/agent_loop_steering_test.go +++ b/bridges/ai/agent_loop_steering_test.go @@ -9,6 +9,29 @@ import ( airuntime "github.com/beeper/agentremote/pkg/runtime" ) +func getFollowUpMessagesForTest(oc *AIClient, roomID id.RoomID) []PromptMessage { + if oc == nil || roomID == "" { + return nil + } + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil { + return nil + } + behavior := airuntime.ResolveQueueBehavior(snapshot.mode) + if !behavior.Followup { + return nil + } + candidate, _ := oc.takePendingQueueDispatchCandidate(roomID, true) + if candidate == nil || len(candidate.items) == 0 { + return nil + } + _, prompt, ok := preparePendingQueueDispatchCandidate(candidate) + if !ok { + return nil + } + return buildSteeringPromptMessages([]string{prompt}) +} + func TestGetSteeringMessages_FiltersAndDrainsQueue(t *testing.T) { roomID := id.RoomID("!room:example.com") oc := &AIClient{ @@ -78,7 +101,7 @@ func TestGetFollowUpMessages_ConsumesSingleQueuedTextMessage(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 1 || messages[0].Role != PromptRoleUser || messages[0].Text() != "follow up" { t.Fatalf("unexpected follow-up messages: %#v", messages) } @@ -101,7 +124,7 @@ func TestGetFollowUpMessages_CollectsQueuedTextMessages(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one combined follow-up message, got %#v", messages) } @@ -127,7 +150,7 @@ func TestGetFollowUpMessages_CollectSummaryIsConsumed(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one combined follow-up message, got %#v", messages) } @@ -135,7 +158,7 @@ func TestGetFollowUpMessages_CollectSummaryIsConsumed(t *testing.T) { t.Fatalf("unexpected combined follow-up prompt with summary: %q", messages[0].Text()) } - if again := oc.getFollowUpMessages(roomID); len(again) != 0 { + if again := getFollowUpMessagesForTest(oc, roomID); len(again) != 0 { t.Fatalf("expected collect summary to be consumed, got %#v", again) } if snapshot := oc.getQueueSnapshot(roomID); snapshot != nil { @@ -159,7 +182,7 @@ func TestGetFollowUpMessages_UsesSyntheticSummaryPrompt(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one synthetic follow-up message, got %#v", messages) } @@ -184,7 +207,7 @@ func TestGetFollowUpMessages_SyntheticSummaryIsConsumedBeforeLatestMessage(t *te }, } - first := oc.getFollowUpMessages(roomID) + first := getFollowUpMessagesForTest(oc, roomID) if len(first) != 1 || first[0].Role != PromptRoleUser { t.Fatalf("expected one synthetic follow-up message, got %#v", first) } @@ -192,7 +215,7 @@ func TestGetFollowUpMessages_SyntheticSummaryIsConsumedBeforeLatestMessage(t *te t.Fatalf("unexpected first synthetic follow-up prompt: %q", first[0].Text()) } - second := oc.getFollowUpMessages(roomID) + second := getFollowUpMessagesForTest(oc, roomID) if len(second) != 1 || second[0].Role != PromptRoleUser { t.Fatalf("expected queued latest message after summary, got %#v", second) } @@ -200,7 +223,7 @@ func TestGetFollowUpMessages_SyntheticSummaryIsConsumedBeforeLatestMessage(t *te t.Fatalf("expected latest queued message after consuming summary, got %q", second[0].Text()) } - if third := oc.getFollowUpMessages(roomID); len(third) != 0 { + if third := getFollowUpMessagesForTest(oc, roomID); len(third) != 0 { t.Fatalf("expected queue to be drained after latest message, got %#v", third) } } @@ -218,7 +241,7 @@ func TestGetFollowUpMessages_LeavesNonTextQueueItemsForBacklogProcessing(t *test }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 0 { t.Fatalf("expected non-text follow-up to stay queued, got %#v", messages) } @@ -240,7 +263,7 @@ func TestGetFollowUpMessages_LeavesNonFollowupQueueUntouched(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 0 { t.Fatalf("expected no follow-up messages for non-followup mode, got %#v", messages) } diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 3332f3d8..a952bdaa 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -351,32 +351,6 @@ func buildSteeringPromptMessages(prompts []string) []PromptMessage { return messages } -func (oc *AIClient) getFollowUpMessages(roomID id.RoomID) []PromptMessage { - if oc == nil || roomID == "" { - return nil - } - snapshot := oc.getQueueSnapshot(roomID) - if snapshot == nil { - return nil - } - behavior := airuntime.ResolveQueueBehavior(snapshot.mode) - if !behavior.Followup { - return nil - } - candidate, _ := oc.takePendingQueueDispatchCandidate(roomID, true) - if candidate == nil || len(candidate.items) == 0 { - return nil - } - for _, item := range candidate.items { - oc.registerRoomRunPendingItem(roomID, item) - } - _, prompt, ok := preparePendingQueueDispatchCandidate(candidate) - if !ok { - return nil - } - return buildSteeringPromptMessages([]string{prompt}) -} - func (oc *AIClient) markQueueDraining(roomID id.RoomID) bool { oc.pendingQueuesMu.Lock() defer oc.pendingQueuesMu.Unlock() diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index f13b0a00..64071164 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -117,16 +117,6 @@ func (oc *AIClient) enqueueSteerQueue(roomID id.RoomID, item pendingQueueItem) b return true } -func (oc *AIClient) registerRoomRunPendingItem(roomID id.RoomID, item pendingQueueItem) { - run := oc.getRoomRun(roomID) - if run == nil { - return - } - run.mu.Lock() - defer run.mu.Unlock() - oc.registerRoomRunPendingItemLocked(run, item) -} - func (oc *AIClient) registerRoomRunPendingItemLocked(run *roomRunState, item pendingQueueItem) { if run == nil { return From 9d200938b11e59fe4bb1d3917cdbb03cf14b7aaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 07:57:19 +0200 Subject: [PATCH 22/23] Improve robustness and add tests across AI bridges Several robustness fixes, behavior tweaks, and tests across the AI bridge code paths: - Merge fallback DesktopAPI token into the default instance when default has no token and add unit test. - Add nil checks and clearer errors in image generation paths (provider resolution and OpenAI images base URL) and add tests for missing login metadata. - Normalize provider lookup in ModelsConfig.Provider to match keys case/whitespace and add a unit test. - Preserve block ordering in prompt building and normalize blank tool-call arguments to "{}"; add test. - Clone []byte values when cloning pending event raw values. - Implement rewriting of trimmed tool-result blocks to preserve block structure when truncating oversized tool results; add test. - Include pendingSteeringPrompts in agent loop/turn completion checks. - Improve pdftotext error reporting when the tool is missing. - Make dummybridge onDisconnect a no-op with an unused param. - Log errors when OpenClaw session sync fails (instead of silently ignoring the error). - Expand example config with compaction/overflow docs and defaults. - Fix connector init hook test to assert the passed bridge pointer and add test for ResolveCommandPrefix trimming behavior. - Tighten SendAIRoomInfo to require portal.Bridge and portal.Bridge.Bot. - Rename ZAIGLM model constant and update agent presets and fallbacks to use the new constant. - Consolidate agent ID extraction in memory integration (use agentIDFromEventMeta) and add nil guards for meta in several places. - Make ZerologFromHost return Nop when host is nil and trim configured command prefix before returning. - Minor test additions and adjustments across multiple AI bridge tests and helpers. These changes aim to reduce panics on nil inputs, preserve data shapes when trimming content, and increase test coverage for edge cases. --- bridges/ai/desktop_api_sessions.go | 6 ++- bridges/ai/desktop_api_sessions_test.go | 43 +++++++++++++++++++ bridges/ai/image_generation_tool.go | 15 +++++++ .../image_generation_tool_magic_proxy_test.go | 14 +++++- bridges/ai/integrations.go | 1 + bridges/ai/integrations_config.go | 8 +++- bridges/ai/integrations_config_test.go | 16 +++++++ bridges/ai/magic_proxy_test.go | 9 ++++ .../media_understanding_runner_openai_test.go | 4 +- bridges/ai/pending_event.go | 2 + bridges/ai/prompt_builder.go | 2 +- bridges/ai/prompt_context_local.go | 6 ++- bridges/ai/prompt_context_local_test.go | 23 ++++++++++ bridges/ai/prompt_projection_local.go | 2 + bridges/ai/response_retry.go | 42 +++++++++++++++++- bridges/ai/response_retry_test.go | 35 +++++++++++++++ bridges/ai/streaming_responses_api.go | 6 +-- bridges/ai/text_files.go | 3 ++ bridges/dummybridge/bridge.go | 4 +- bridges/openclaw/manager.go | 4 +- config.example.yaml | 9 ++++ connector_builder_test.go | 7 ++- helpers.go | 2 +- pkg/agents/beeper.go | 2 +- pkg/agents/boss.go | 2 +- pkg/agents/presets.go | 2 +- pkg/integrations/memory/integration.go | 18 ++++---- pkg/integrations/runtime/helpers.go | 2 +- sdk/connector_helpers.go | 5 ++- sdk/connector_hooks_test.go | 9 ++++ sdk/conversation_state_test.go | 4 +- 31 files changed, 274 insertions(+), 33 deletions(-) create mode 100644 bridges/ai/desktop_api_sessions_test.go create mode 100644 bridges/ai/integrations_config_test.go create mode 100644 bridges/ai/prompt_context_local_test.go diff --git a/bridges/ai/desktop_api_sessions.go b/bridges/ai/desktop_api_sessions.go index 2911d6a1..29121000 100644 --- a/bridges/ai/desktop_api_sessions.go +++ b/bridges/ai/desktop_api_sessions.go @@ -173,8 +173,10 @@ func (oc *AIClient) desktopAPIInstances() map[string]DesktopAPIInstance { instances[key] = instance } if token := strings.TrimSpace(creds.ServiceTokens.DesktopAPI); token != "" { - if _, ok := instances[desktopDefaultInstance]; !ok { - instances[desktopDefaultInstance] = DesktopAPIInstance{Token: token} + instance := instances[desktopDefaultInstance] + if strings.TrimSpace(instance.Token) == "" { + instance.Token = token + instances[desktopDefaultInstance] = instance } } return instances diff --git a/bridges/ai/desktop_api_sessions_test.go b/bridges/ai/desktop_api_sessions_test.go new file mode 100644 index 00000000..3b97cb4e --- /dev/null +++ b/bridges/ai/desktop_api_sessions_test.go @@ -0,0 +1,43 @@ +package ai + +import ( + "testing" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +func TestDesktopAPIInstancesMergesFallbackTokenIntoDefaultInstance(t *testing.T) { + client := &AIClient{ + UserLogin: &bridgev2.UserLogin{ + UserLogin: &database.UserLogin{ + ID: networkid.UserLoginID("login"), + Metadata: &UserLoginMetadata{ + Credentials: &LoginCredentials{ + ServiceTokens: &ServiceTokens{ + DesktopAPI: "fallback-token", + DesktopAPIInstances: map[string]DesktopAPIInstance{ + "default": {BaseURL: "https://desktop.example"}, + }, + }, + }, + }, + }, + Log: zerolog.Nop(), + }, + } + + instances := client.desktopAPIInstances() + got, ok := instances[desktopDefaultInstance] + if !ok { + t.Fatal("expected default desktop API instance") + } + if got.Token != "fallback-token" { + t.Fatalf("expected fallback token to be merged, got %#v", got) + } + if got.BaseURL != "https://desktop.example" { + t.Fatalf("expected base URL to be preserved, got %#v", got) + } +} diff --git a/bridges/ai/image_generation_tool.go b/bridges/ai/image_generation_tool.go index 1df8a926..79e2cda5 100644 --- a/bridges/ai/image_generation_tool.go +++ b/bridges/ai/image_generation_tool.go @@ -153,6 +153,9 @@ func readStringSlice(args map[string]any, key string) []string { } func resolveImageGenProvider(req imageGenRequest, btc *BridgeToolContext) (imageGenProvider, error) { + if btc == nil || btc.Client == nil || btc.Client.UserLogin == nil { + return "", errors.New("image generation is not available for this login") + } provider := strings.ToLower(strings.TrimSpace(req.Provider)) if provider != "" { switch provider { @@ -177,6 +180,9 @@ func resolveImageGenProvider(req imageGenRequest, btc *BridgeToolContext) (image } loginMeta := loginMetadata(btc.Client.UserLogin) + if loginMeta == nil { + return "", errors.New("image generation is not available for this login") + } inferredProvider := inferProviderFromModel(req.Model) if inferredProvider != "" { switch inferredProvider { @@ -251,6 +257,9 @@ func supportsOpenAIImageGen(btc *BridgeToolContext) bool { return false } loginMeta := loginMetadata(btc.Client.UserLogin) + if loginMeta == nil { + return false + } switch loginMeta.Provider { case ProviderOpenAI, ProviderMagicProxy: if loginMeta.Provider == ProviderMagicProxy { @@ -469,7 +478,13 @@ func isAllowedValue(value string, allowed map[string]bool) bool { } func buildOpenAIImagesBaseURL(btc *BridgeToolContext) (string, error) { + if btc == nil || btc.Client == nil || btc.Client.UserLogin == nil || btc.Client.UserLogin.Metadata == nil { + return "", errors.New("openai image generation not available for this provider") + } loginMeta := loginMetadata(btc.Client.UserLogin) + if loginMeta == nil { + return "", errors.New("openai image generation not available for this provider") + } switch loginMeta.Provider { case ProviderOpenAI: base := btc.Client.connector.resolveOpenAIBaseURL() diff --git a/bridges/ai/image_generation_tool_magic_proxy_test.go b/bridges/ai/image_generation_tool_magic_proxy_test.go index ef8ffdcf..7cf53751 100644 --- a/bridges/ai/image_generation_tool_magic_proxy_test.go +++ b/bridges/ai/image_generation_tool_magic_proxy_test.go @@ -1,6 +1,8 @@ package ai -import "testing" +import ( + "testing" +) func TestResolveImageGenProviderMagicProxyPrefersOpenAIForSimplePrompts(t *testing.T) { meta := &UserLoginMetadata{ @@ -170,3 +172,13 @@ func TestBuildGeminiBaseURLMagicProxy(t *testing.T) { t.Fatalf("unexpected base url: %q", baseURL) } } + +func TestResolveImageGenProviderRejectsMissingLoginMetadata(t *testing.T) { + btc := &BridgeToolContext{ + Client: &AIClient{}, + } + + if _, err := resolveImageGenProvider(imageGenRequest{Prompt: "cat"}, btc); err == nil { + t.Fatal("expected missing login metadata to be rejected") + } +} diff --git a/bridges/ai/integrations.go b/bridges/ai/integrations.go index 86db2db3..69fceef8 100644 --- a/bridges/ai/integrations.go +++ b/bridges/ai/integrations.go @@ -533,6 +533,7 @@ func notifyIntegrationFileChanged(ctx context.Context, path string) { btc.Client.emitIntegrationFileChanged(ctx, btc.Portal, meta, path) } +// purgeLoginIntegrations keeps the login argument for parity with the logout cleanup call site. func (oc *AIClient) purgeLoginIntegrations(ctx context.Context, _ *bridgev2.UserLogin, bridgeID, loginID string) { if oc == nil || oc.purgeRegistry == nil { return diff --git a/bridges/ai/integrations_config.go b/bridges/ai/integrations_config.go index fad865eb..59fc1de4 100644 --- a/bridges/ai/integrations_config.go +++ b/bridges/ai/integrations_config.go @@ -448,7 +448,13 @@ func (cfg *ModelsConfig) Provider(name string) ModelProviderConfig { if cfg == nil || len(cfg.Providers) == 0 { return ModelProviderConfig{} } - return cfg.Providers[strings.ToLower(strings.TrimSpace(name))] + normalized := strings.ToLower(strings.TrimSpace(name)) + for key, provider := range cfg.Providers { + if strings.ToLower(strings.TrimSpace(key)) == normalized { + return provider + } + } + return ModelProviderConfig{} } // ModelDefinitionConfig defines a model entry for catalog seeding. diff --git a/bridges/ai/integrations_config_test.go b/bridges/ai/integrations_config_test.go new file mode 100644 index 00000000..d1ca6e06 --- /dev/null +++ b/bridges/ai/integrations_config_test.go @@ -0,0 +1,16 @@ +package ai + +import "testing" + +func TestModelsConfigProviderMatchesNormalizedKeys(t *testing.T) { + cfg := &ModelsConfig{ + Providers: map[string]ModelProviderConfig{ + " OpenAI ": {APIKey: "tok"}, + }, + } + + got := cfg.Provider("openai") + if got.APIKey != "tok" { + t.Fatalf("expected normalized provider lookup to match, got %#v", got) + } +} diff --git a/bridges/ai/magic_proxy_test.go b/bridges/ai/magic_proxy_test.go index f231482c..35550a09 100644 --- a/bridges/ai/magic_proxy_test.go +++ b/bridges/ai/magic_proxy_test.go @@ -50,9 +50,15 @@ func TestResolveServiceConfigMagicProxyUsesJoinedPaths(t *testing.T) { if got := services[serviceOpenRouter].BaseURL; got != "https://bai.bt.hn/team/proxy/openrouter/v1" { t.Fatalf("unexpected openrouter base URL: %q", got) } + if got := services[serviceOpenRouter].APIKey; got != "tok" { + t.Fatalf("unexpected openrouter api key: %q", got) + } if got := services[serviceOpenAI].BaseURL; got != "https://bai.bt.hn/team/proxy/openai/v1" { t.Fatalf("unexpected openai base URL: %q", got) } + if got := services[serviceOpenAI].APIKey; got != "tok" { + t.Fatalf("unexpected openai api key: %q", got) + } if got := services[serviceGemini].BaseURL; got != "https://bai.bt.hn/team/proxy/gemini/v1beta" { t.Fatalf("unexpected gemini base URL: %q", got) } @@ -76,6 +82,9 @@ func TestResolveServiceConfigMagicProxyNoDuplicateOpenRouterPath(t *testing.T) { if strings.Count(base, "/openrouter/v1") != 1 { t.Fatalf("openrouter path duplicated: %q", base) } + if got := services[serviceOpenRouter].APIKey; got != "tok" { + t.Fatalf("unexpected openrouter api key: %q", got) + } if got := services[serviceExa].BaseURL; got != "https://bai.bt.hn/team/proxy/exa" { t.Fatalf("unexpected exa base URL: %q", got) } diff --git a/bridges/ai/media_understanding_runner_openai_test.go b/bridges/ai/media_understanding_runner_openai_test.go index 2d784cf5..14d595b5 100644 --- a/bridges/ai/media_understanding_runner_openai_test.go +++ b/bridges/ai/media_understanding_runner_openai_test.go @@ -57,7 +57,9 @@ func TestResolveOpenRouterMediaConfigUsesEntryOverrides(t *testing.T) { t.Setenv("OPENROUTER_API_KEY_SPECIAL_PROFILE", "entry-key") client := newMediaTestClient(&UserLoginMetadata{Provider: ProviderOpenAI}, &OpenAIConnector{ - Config: Config{}, + Config: Config{ + Agents: &AgentsConfig{Defaults: &AgentDefaultsConfig{PDFEngine: "mistral-ocr"}}, + }, }) cfg := &MediaUnderstandingConfig{ diff --git a/bridges/ai/pending_event.go b/bridges/ai/pending_event.go index 6a357a65..e2143ba9 100644 --- a/bridges/ai/pending_event.go +++ b/bridges/ai/pending_event.go @@ -47,6 +47,8 @@ func clonePendingRawValue(v any) any { return clonePendingRawMap(typed) case []any: return clonePendingRawSlice(typed) + case []byte: + return append([]byte(nil), typed...) default: return v } diff --git a/bridges/ai/prompt_builder.go b/bridges/ai/prompt_builder.go index 5a4d1774..e55ccad8 100644 --- a/bridges/ai/prompt_builder.go +++ b/bridges/ai/prompt_builder.go @@ -225,10 +225,10 @@ func (oc *AIClient) buildPromptContextForTurn( } blocks := make([]PromptBlock, 0, len(leadingBlocks)+1) + blocks = append(blocks, leadingBlocks...) if strings.TrimSpace(text) != "" { blocks = append(blocks, PromptBlock{Type: PromptBlockText, Text: text}) } - blocks = append(blocks, leadingBlocks...) base.Messages = append(base.Messages, PromptMessage{ Role: PromptRoleUser, Blocks: blocks, diff --git a/bridges/ai/prompt_context_local.go b/bridges/ai/prompt_context_local.go index b8a2f550..8514aa2c 100644 --- a/bridges/ai/prompt_context_local.go +++ b/bridges/ai/prompt_context_local.go @@ -205,12 +205,16 @@ func promptAssistantToChatMessage(msg PromptMessage) *openai.ChatCompletionAssis if strings.TrimSpace(block.ToolCallID) == "" || strings.TrimSpace(block.ToolName) == "" { continue } + args := strings.TrimSpace(block.ToolCallArguments) + if args == "" { + args = "{}" + } toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallUnionParam{ OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ ID: block.ToolCallID, Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ Name: block.ToolName, - Arguments: block.ToolCallArguments, + Arguments: args, }, }, }) diff --git a/bridges/ai/prompt_context_local_test.go b/bridges/ai/prompt_context_local_test.go new file mode 100644 index 00000000..0d6814ea --- /dev/null +++ b/bridges/ai/prompt_context_local_test.go @@ -0,0 +1,23 @@ +package ai + +import "testing" + +func TestPromptAssistantToChatMessageNormalizesBlankToolArguments(t *testing.T) { + msg := PromptMessage{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{{ + Type: PromptBlockToolCall, + ToolCallID: "call_123", + ToolName: "search", + ToolCallArguments: " ", + }}, + } + + assistant := promptAssistantToChatMessage(msg) + if assistant == nil || len(assistant.ToolCalls) != 1 || assistant.ToolCalls[0].OfFunction == nil { + t.Fatalf("expected one function tool call, got %#v", assistant) + } + if got := assistant.ToolCalls[0].OfFunction.Function.Arguments; got != "{}" { + t.Fatalf("expected blank tool arguments to normalize to {}, got %q", got) + } +} diff --git a/bridges/ai/prompt_projection_local.go b/bridges/ai/prompt_projection_local.go index a31592e4..3e03119c 100644 --- a/bridges/ai/prompt_projection_local.go +++ b/bridges/ai/prompt_projection_local.go @@ -98,6 +98,8 @@ func promptMessagesFromTurnData(td sdk.TurnData) []PromptMessage { } } +// turnDataFromUserPromptMessages intentionally projects only the latest user +// message because callers pass a single-message tail via promptTail(..., 1). func turnDataFromUserPromptMessages(messages []PromptMessage) (sdk.TurnData, bool) { if len(messages) == 0 { return sdk.TurnData{}, false diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index a140a893..a999ace5 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -472,12 +472,52 @@ func (oc *AIClient) truncateOversizedToolResultsForOverflow( if trimmed == content { continue } - out.Messages[i].Blocks = []PromptBlock{{Type: PromptBlockText, Text: trimmed}} + out.Messages[i].Blocks = rewriteTrimmedToolResultBlocks(msg.Blocks, trimmed) truncated++ } return out, truncated } +func rewriteTrimmedToolResultBlocks(blocks []PromptBlock, trimmed string) []PromptBlock { + if len(blocks) == 0 { + return []PromptBlock{{Type: PromptBlockText, Text: trimmed}} + } + remaining := trimmed + rewritten := make([]PromptBlock, 0, len(blocks)) + previousTextBlock := false + for _, block := range blocks { + if remaining == "" { + break + } + switch block.Type { + case PromptBlockText, PromptBlockThinking: + default: + continue + } + if block.Text == "" { + continue + } + if previousTextBlock && strings.HasPrefix(remaining, "\n") { + remaining = remaining[1:] + if remaining == "" { + break + } + } + take := len(block.Text) + if take > len(remaining) { + take = len(remaining) + } + block.Text = remaining[:take] + remaining = remaining[take:] + rewritten = append(rewritten, block) + previousTextBlock = true + } + if len(rewritten) == 0 { + return []PromptBlock{{Type: PromptBlockText, Text: trimmed}} + } + return rewritten +} + // emitCompactionStatus sends a compaction status event to the room func (oc *AIClient) emitCompactionStatus(ctx context.Context, portal *bridgev2.Portal, evt *CompactionEvent) { if portal == nil || portal.MXID == "" { diff --git a/bridges/ai/response_retry_test.go b/bridges/ai/response_retry_test.go index a6827ed2..aa909070 100644 --- a/bridges/ai/response_retry_test.go +++ b/bridges/ai/response_retry_test.go @@ -1,6 +1,7 @@ package ai import ( + "strings" "testing" "github.com/rs/zerolog" @@ -173,3 +174,37 @@ func TestPruningPostCompactionRefreshPrompt_Defaults(t *testing.T) { t.Fatal("expected non-empty post-compaction refresh prompt") } } + +func TestTruncateOversizedToolResultsForOverflowPreservesBlockStructure(t *testing.T) { + client := newPruningTestClient(&airuntime.PruningConfig{ + SoftTrimMaxChars: 20, + SoftTrimHeadChars: 8, + SoftTrimTailChars: 8, + }, ProviderOpenAI) + + prompt := PromptContext{ + Messages: []PromptMessage{{ + Role: PromptRoleToolResult, + Blocks: []PromptBlock{ + {Type: PromptBlockText, Text: strings.Repeat("first-block-", 12)}, + {Type: PromptBlockText, Text: strings.Repeat("second-block-", 12)}, + }, + }}, + } + + got, truncated := client.truncateOversizedToolResultsForOverflow(prompt, 0) + if truncated != 1 { + t.Fatalf("expected one truncated tool result, got %d", truncated) + } + if len(got.Messages) != 1 || len(got.Messages[0].Blocks) == 0 { + t.Fatalf("expected rewritten blocks, got %#v", got.Messages) + } + for _, block := range got.Messages[0].Blocks { + if block.Type != PromptBlockText { + t.Fatalf("expected text blocks to be preserved, got %#v", got.Messages[0].Blocks) + } + } + if got.Messages[0].Text() != airuntime.SoftTrimToolResult(prompt.Messages[0].Text(), client.pruningConfigOrDefault()) { + t.Fatalf("expected trimmed text to match soft-trim output, got %q", got.Messages[0].Text()) + } +} diff --git a/bridges/ai/streaming_responses_api.go b/bridges/ai/streaming_responses_api.go index 20b682d8..861907e3 100644 --- a/bridges/ai/streaming_responses_api.go +++ b/bridges/ai/streaming_responses_api.go @@ -115,7 +115,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( return false, nil, &PreDeltaError{Err: err} } } else { - if len(state.pendingFunctionOutputs) == 0 && len(state.pendingMcpApprovals) == 0 { + if len(state.pendingFunctionOutputs) == 0 && len(state.pendingMcpApprovals) == 0 && len(state.pendingSteeringPrompts) == 0 { return false, nil, nil } if round > maxAgentLoopToolTurns { @@ -166,10 +166,10 @@ func (a *responsesTurnAdapter) RunAgentTurn( return false, cle, err } if done { - return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0), nil, nil + return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0 || len(state.pendingSteeringPrompts) > 0), nil, nil } - return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0), nil, nil + return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0 || len(state.pendingSteeringPrompts) > 0), nil, nil } func (a *responsesTurnAdapter) FinalizeAgentLoop(ctx context.Context) { diff --git a/bridges/ai/text_files.go b/bridges/ai/text_files.go index 3690b1a7..43515f1d 100644 --- a/bridges/ai/text_files.go +++ b/bridges/ai/text_files.go @@ -216,6 +216,9 @@ func (oc *AIClient) downloadPDFFile(ctx context.Context, mediaURL string, encryp cmd := exec.CommandContext(ctx, "pdftotext", "-layout", "-enc", "UTF-8", inputPath, "-") output, err := cmd.Output() if err != nil { + if errors.Is(err, exec.ErrNotFound) { + return "", false, fmt.Errorf("pdftotext not found: install poppler-utils (or poppler) and ensure pdftotext is on PATH: %w", err) + } if exitErr, ok := err.(*exec.ExitError); ok { msg := strings.TrimSpace(string(exitErr.Stderr)) if msg != "" { diff --git a/bridges/dummybridge/bridge.go b/bridges/dummybridge/bridge.go index e9130798..656be714 100644 --- a/bridges/dummybridge/bridge.go +++ b/bridges/dummybridge/bridge.go @@ -57,9 +57,7 @@ func (dc *DummyBridgeConnector) onConnect(ctx context.Context, info *bridgesdk.L }, nil } -func (dc *DummyBridgeConnector) onDisconnect(session *dummySession) { - _, _ = requireSession(session) -} +func (dc *DummyBridgeConnector) onDisconnect(_ *dummySession) {} func (dc *DummyBridgeConnector) getContactList(ctx context.Context, session *dummySession) ([]*bridgev2.ResolveIdentifierResponse, error) { dummy, err := requireSession(session) diff --git a/bridges/openclaw/manager.go b/bridges/openclaw/manager.go index 35e0124e..96afc01f 100644 --- a/bridges/openclaw/manager.go +++ b/bridges/openclaw/manager.go @@ -594,7 +594,9 @@ func (m *openClawManager) HandleMatrixMessage(ctx context.Context, msg *bridgev2 } if meta.OpenClawDMCreatedFromContact && meta.OpenClawSessionID == "" && isOpenClawSyntheticDMSessionKey(meta.OpenClawSessionKey) { go func() { - _ = m.syncSessions(m.client.BackgroundContext(ctx)) + if err := m.syncSessions(m.client.BackgroundContext(ctx)); err != nil { + m.client.Log().Debug().Err(err).Str("session_key", meta.OpenClawSessionKey).Msg("Failed to refresh OpenClaw sessions after synthetic DM message") + } }() } return &bridgev2.MatrixMessageResponse{Pending: true}, nil diff --git a/config.example.yaml b/config.example.yaml index 6fe29f82..3605cb70 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -184,21 +184,29 @@ network: fallbacks: [] pdf_engine: "mistral-ocr" compaction: + # cache-ttl keeps a cached compacted prompt for ttl; other modes can force + # more aggressive recomputation or fallback behavior in runtime defaults. mode: "cache-ttl" + # ttl applies to cached compaction snapshots when mode uses cache expiry. ttl: "1h" + # enabled gates soft trimming, hard clear, summarization, and token reserves. enabled: true soft_trim_ratio: 0.3 hard_clear_ratio: 0.5 + # keep a few recent assistant turns and require enough prunable text before trimming. keep_last_assistants: 3 min_prunable_chars: 50000 + # soft trim keeps head/tail context from oversized tool results before full compaction. soft_trim_max_chars: 4000 soft_trim_head_chars: 1500 soft_trim_tail_chars: 1500 hard_clear_enabled: true hard_clear_placeholder: "[Old tool result content cleared]" + # summarization condenses old history before hard clearing it entirely. summarization_enabled: true summarization_model: "openai/gpt-5.2" max_summary_tokens: 500 + # safeguard preserves recent history/tokens; alternative modes may trade fidelity for space. compaction_mode: "safeguard" keep_recent_tokens: 20000 max_history_share: 0.5 @@ -207,6 +215,7 @@ network: identifier_policy: "strict" post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." overflow_flush: + # overflow flush runs a last tool-only pass near the soft threshold before compaction. enabled: true soft_threshold_tokens: 4000 prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." diff --git a/connector_builder_test.go b/connector_builder_test.go index 135c155b..9b751f84 100644 --- a/connector_builder_test.go +++ b/connector_builder_test.go @@ -16,7 +16,12 @@ func TestConnectorBaseHookOrder(t *testing.T) { var order []string wantBridge := &bridgev2.Bridge{} conn := NewConnector(ConnectorSpec{ - Init: func(*bridgev2.Bridge) { order = append(order, "init") }, + Init: func(got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected init hook bridge %p, got %p", wantBridge, got) + } + order = append(order, "init") + }, Start: func(_ context.Context, got *bridgev2.Bridge) error { if got != wantBridge { t.Fatalf("expected start hook bridge %p, got %p", wantBridge, got) diff --git a/helpers.go b/helpers.go index dfd799e8..e403e3c2 100644 --- a/helpers.go +++ b/helpers.go @@ -450,7 +450,7 @@ func ApplyAgentRemoteBridgeInfo(content *event.BridgeEventContent, protocolID st } func SendAIRoomInfo(ctx context.Context, portal *bridgev2.Portal, aiKind string) bool { - if portal == nil || portal.MXID == "" { + if portal == nil || portal.MXID == "" || portal.Bridge == nil || portal.Bridge.Bot == nil { return false } if aiKind == "" { diff --git a/pkg/agents/beeper.go b/pkg/agents/beeper.go index 13253a5b..1e2a7e2b 100644 --- a/pkg/agents/beeper.go +++ b/pkg/agents/beeper.go @@ -19,7 +19,7 @@ var BeeperAIAgent = &AgentDefinition{ Fallbacks: []string{ ModelClaudeSonnet, ModelOpenAIGPT52, - ModelZAIGLM47, + ModelZAIGLM5Turbo, }, }, Tools: &toolpolicy.ToolPolicyConfig{Profile: toolpolicy.ProfileFull}, diff --git a/pkg/agents/boss.go b/pkg/agents/boss.go index deb1118e..d47c0b69 100644 --- a/pkg/agents/boss.go +++ b/pkg/agents/boss.go @@ -13,7 +13,7 @@ var BossAgent = &AgentDefinition{ Fallbacks: []string{ ModelClaudeSonnet, ModelOpenAIGPT52, - ModelZAIGLM47, + ModelZAIGLM5Turbo, }, }, Tools: &toolpolicy.ToolPolicyConfig{Profile: toolpolicy.ProfileBoss}, diff --git a/pkg/agents/presets.go b/pkg/agents/presets.go index 354f37a7..b16ad307 100644 --- a/pkg/agents/presets.go +++ b/pkg/agents/presets.go @@ -7,7 +7,7 @@ const ( ModelClaudeSonnet = "anthropic/claude-sonnet-4.6" ModelClaudeOpus = "anthropic/claude-opus-4.6" ModelOpenAIGPT52 = "openai/gpt-5.2" - ModelZAIGLM47 = "z-ai/glm-5-turbo" + ModelZAIGLM5Turbo = "z-ai/glm-5-turbo" ) // PresetAgents contains the default agent definitions: diff --git a/pkg/integrations/memory/integration.go b/pkg/integrations/memory/integration.go index e209a835..31281db3 100644 --- a/pkg/integrations/memory/integration.go +++ b/pkg/integrations/memory/integration.go @@ -70,7 +70,7 @@ func (i *Integration) ToolAvailability(_ context.Context, scope iruntime.ToolSco return false, false, iruntime.SourceGlobalDefault, "" } if scope.Meta != nil { - agentID := scope.Meta.AgentID() + agentID := i.agentIDFromEventMeta(scope.Meta) _, errMsg := i.getManager(agentID) if errMsg != "" { return true, false, iruntime.SourceProviderLimit, errMsg @@ -233,6 +233,9 @@ func (i *Integration) buildOverflowDeps() OverflowDeps { return i.host.EstimateTokens(prompt, model) }, AlreadyFlushed: func(call iruntime.ContextOverflowCall) bool { + if call.Meta == nil { + return false + } flushAtMs := toInt64(call.Meta.ModuleMetaValue("overflow_flush_at")) if flushAtMs == 0 { return false @@ -266,6 +269,9 @@ func (i *Integration) shouldInjectMemoryPromptContext(_ *bridgev2.Portal, _ irun } func (i *Integration) shouldBootstrapMemoryPromptContext(_ *bridgev2.Portal, meta iruntime.Meta) bool { + if meta == nil { + return false + } raw := meta.ModuleMetaValue("memory_bootstrap_at") if raw == nil { return true @@ -296,10 +302,7 @@ func (i *Integration) markMemoryPromptBootstrapped(ctx context.Context, portal * } func (i *Integration) readMemoryPromptSection(ctx context.Context, meta iruntime.Meta, path string) string { - agentID := "" - if meta != nil { - agentID = meta.AgentID() - } + agentID := i.agentIDFromEventMeta(meta) content, filePath, found, err := i.host.ReadTextFile(ctx, agentID, path) if err != nil || !found { return "" @@ -438,10 +441,7 @@ func (i *Integration) writeMemoryCommandFile( content string, maxBytes int, ) (string, error) { - agentID := "" - if scope.Meta != nil { - agentID = scope.Meta.AgentID() - } + agentID := i.agentIDFromEventMeta(scope.Meta) return i.host.WriteTextFile(ctx, scope.Portal, scope.Meta, agentID, mode, path, content, maxBytes) } diff --git a/pkg/integrations/runtime/helpers.go b/pkg/integrations/runtime/helpers.go index 8923537b..d63a138a 100644 --- a/pkg/integrations/runtime/helpers.go +++ b/pkg/integrations/runtime/helpers.go @@ -7,7 +7,7 @@ import ( ) // ZerologFromHost extracts a zerolog.Logger from a Host. -// Returns zerolog.Nop() if the underlying logger is not a zerolog.Logger. +// Returns zerolog.Nop() if the host is nil. func ZerologFromHost(host Host) zerolog.Logger { if host == nil { return zerolog.Nop() diff --git a/sdk/connector_helpers.go b/sdk/connector_helpers.go index 3583abf4..4fd0c7e8 100644 --- a/sdk/connector_helpers.go +++ b/sdk/connector_helpers.go @@ -40,8 +40,9 @@ func ApplyDefaultCommandPrefix(prefix *string, value string) { // ResolveCommandPrefix returns the configured prefix when present, otherwise the // bridge's declared default prefix without mutating configuration state. func ResolveCommandPrefix(prefix string, fallback string) string { - if strings.TrimSpace(prefix) != "" { - return prefix + trimmed := strings.TrimSpace(prefix) + if trimmed != "" { + return trimmed } return fallback } diff --git a/sdk/connector_hooks_test.go b/sdk/connector_hooks_test.go index 6670ad03..2207f7b0 100644 --- a/sdk/connector_hooks_test.go +++ b/sdk/connector_hooks_test.go @@ -216,4 +216,13 @@ func TestApprovalControllerUsesCustomHandler(t *testing.T) { } } +func TestResolveCommandPrefixTrimsConfiguredValue(t *testing.T) { + if got := ResolveCommandPrefix(" /ai ", "!fallback"); got != "/ai" { + t.Fatalf("expected trimmed configured prefix, got %q", got) + } + if got := ResolveCommandPrefix(" ", "!fallback"); got != "!fallback" { + t.Fatalf("expected fallback prefix, got %q", got) + } +} + var _ bridgev2.NetworkAPI = (*testSDKClient)(nil) diff --git a/sdk/conversation_state_test.go b/sdk/conversation_state_test.go index bbef53de..8cdc1830 100644 --- a/sdk/conversation_state_test.go +++ b/sdk/conversation_state_test.go @@ -81,7 +81,9 @@ func TestConversationStateRoundTripCarrierMetadata(t *testing.T) { } // saveConversationStateToGenericMetadata intentionally returns false here // because generic metadata doesn't support the carrier path. - _ = saveConversationStateToGenericMetadata(&holder, state) + if ok := saveConversationStateToGenericMetadata(&holder, state); ok { + t.Fatalf("expected generic metadata save to report unsupported carrier path") + } carrier.SetSDKPortalMetadata(&SDKPortalMetadata{Conversation: *state}) loaded, ok := carrier.GetSDKPortalMetadata(), carrier.GetSDKPortalMetadata() != nil if !ok || loaded == nil { From 93ca4f336c00456236df14096e89aeac00b3fbe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Mon, 30 Mar 2026 08:23:39 +0200 Subject: [PATCH 23/23] Make pending queue mutex-safe; normalize models Add synchronization to pending queue operations by introducing a mutex and locking critical sections to ensure thread-safety. Normalize model provider keys during YAML unmarshalling (ModelsConfig.UnmarshalYAML) so lookups are case/whitespace-insensitive and reject collisions; update Provider() to use the normalized map. Improve canonicalPromptToolArguments to preserve canonical JSON, JSON-decode strings when possible, and JSON-encode plain strings; add tests for this behavior. Refactor agent_loop_steering_test to add subtests for nil vs non-nil prompt handling and add tests for model normalization collision detection. --- bridges/ai/agent_loop_steering_test.go | 69 ++++++++++++++-------- bridges/ai/integrations_config.go | 33 ++++++++--- bridges/ai/integrations_config_test.go | 37 ++++++++++-- bridges/ai/pending_queue.go | 25 ++++++++ bridges/ai/prompt_projection_local.go | 28 +++++++++ bridges/ai/prompt_projection_local_test.go | 15 +++++ 6 files changed, 172 insertions(+), 35 deletions(-) create mode 100644 bridges/ai/prompt_projection_local_test.go diff --git a/bridges/ai/agent_loop_steering_test.go b/bridges/ai/agent_loop_steering_test.go index 74776b61..e2036588 100644 --- a/bridges/ai/agent_loop_steering_test.go +++ b/bridges/ai/agent_loop_steering_test.go @@ -274,31 +274,54 @@ func TestGetFollowUpMessages_LeavesNonFollowupQueueUntouched(t *testing.T) { func TestBuildContinuationParams_UsesPendingSteeringPromptsBeforeDrainingQueue(t *testing.T) { roomID := id.RoomID("!room:example.com") - oc := &AIClient{ - connector: &OpenAIConnector{}, - activeRoomRuns: map[id.RoomID]*roomRunState{ - roomID: { - steerQueue: []pendingQueueItem{ - {pending: pendingMessage{Type: pendingTypeText, MessageBody: "queue steer"}}, + newClient := func() *AIClient { + return &AIClient{ + connector: &OpenAIConnector{}, + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + steerQueue: []pendingQueueItem{ + {pending: pendingMessage{Type: pendingTypeText, MessageBody: "queue steer"}}, + }, }, }, - }, + } } - state := &streamingState{roomID: roomID} - state.addPendingSteeringPrompts([]string{"pending steer"}) - prompt := PromptContext{} - params := oc.buildContinuationParams(context.Background(), &prompt, state, nil, nil, nil) - if len(params.Input.OfInputItemList) == 0 { - t.Fatal("expected continuation input to include stored steering prompt") - } - if pending := state.consumePendingSteeringPrompts(); len(pending) != 0 { - t.Fatalf("expected pending steering prompts to be consumed, got %#v", pending) - } - if len(prompt.Messages) == 0 { - t.Fatal("expected steering input to persist in canonical prompt even when history starts empty") - } - if snapshot := oc.getRoomRun(roomID); snapshot == nil || len(snapshot.steerQueue) != 1 { - t.Fatalf("expected queued steering item to remain available, got %#v", snapshot) - } + t.Run("non-nil prompt", func(t *testing.T) { + oc := newClient() + state := &streamingState{roomID: roomID} + state.addPendingSteeringPrompts([]string{"pending steer"}) + prompt := PromptContext{} + + params := oc.buildContinuationParams(context.Background(), &prompt, state, nil, nil, nil) + if len(params.Input.OfInputItemList) == 0 { + t.Fatal("expected continuation input to include stored steering prompt") + } + if pending := state.consumePendingSteeringPrompts(); len(pending) != 0 { + t.Fatalf("expected pending steering prompts to be consumed, got %#v", pending) + } + if len(prompt.Messages) == 0 { + t.Fatal("expected steering input to persist in canonical prompt even when history starts empty") + } + if snapshot := oc.getRoomRun(roomID); snapshot == nil || len(snapshot.steerQueue) != 1 { + t.Fatalf("expected queued steering item to remain available, got %#v", snapshot) + } + }) + + t.Run("nil prompt", func(t *testing.T) { + oc := newClient() + state := &streamingState{roomID: roomID} + state.addPendingSteeringPrompts([]string{"pending steer"}) + + params := oc.buildContinuationParams(context.Background(), nil, state, nil, nil, nil) + if len(params.Input.OfInputItemList) == 0 { + t.Fatal("expected continuation input to include stored steering prompt") + } + if pending := state.consumePendingSteeringPrompts(); len(pending) != 0 { + t.Fatalf("expected pending steering prompts to be consumed, got %#v", pending) + } + if snapshot := oc.getRoomRun(roomID); snapshot == nil || len(snapshot.steerQueue) != 1 { + t.Fatalf("expected queued steering item to remain available, got %#v", snapshot) + } + }) } diff --git a/bridges/ai/integrations_config.go b/bridges/ai/integrations_config.go index 59fc1de4..c36f6549 100644 --- a/bridges/ai/integrations_config.go +++ b/bridges/ai/integrations_config.go @@ -2,6 +2,7 @@ package ai import ( _ "embed" + "fmt" "strings" "time" @@ -444,17 +445,35 @@ type ModelSelectionConfig struct { Fallbacks []string `yaml:"fallbacks"` } +func (cfg *ModelsConfig) UnmarshalYAML(unmarshal func(any) error) error { + type rawModelsConfig ModelsConfig + var raw rawModelsConfig + if err := unmarshal(&raw); err != nil { + return err + } + if len(raw.Providers) > 0 { + normalizedProviders := make(map[string]ModelProviderConfig, len(raw.Providers)) + for key, provider := range raw.Providers { + normalized := strings.ToLower(strings.TrimSpace(key)) + if normalized == "" { + return fmt.Errorf("models.providers contains an empty provider key") + } + if _, exists := normalizedProviders[normalized]; exists { + return fmt.Errorf("models.providers contains duplicate provider key after normalization: %q", key) + } + normalizedProviders[normalized] = provider + } + raw.Providers = normalizedProviders + } + *cfg = ModelsConfig(raw) + return nil +} + func (cfg *ModelsConfig) Provider(name string) ModelProviderConfig { if cfg == nil || len(cfg.Providers) == 0 { return ModelProviderConfig{} } - normalized := strings.ToLower(strings.TrimSpace(name)) - for key, provider := range cfg.Providers { - if strings.ToLower(strings.TrimSpace(key)) == normalized { - return provider - } - } - return ModelProviderConfig{} + return cfg.Providers[strings.ToLower(strings.TrimSpace(name))] } // ModelDefinitionConfig defines a model entry for catalog seeding. diff --git a/bridges/ai/integrations_config_test.go b/bridges/ai/integrations_config_test.go index d1ca6e06..81147fc4 100644 --- a/bridges/ai/integrations_config_test.go +++ b/bridges/ai/integrations_config_test.go @@ -1,12 +1,21 @@ package ai -import "testing" +import ( + "strings" + "testing" + + "gopkg.in/yaml.v3" +) func TestModelsConfigProviderMatchesNormalizedKeys(t *testing.T) { - cfg := &ModelsConfig{ - Providers: map[string]ModelProviderConfig{ - " OpenAI ": {APIKey: "tok"}, - }, + var cfg ModelsConfig + if err := yaml.Unmarshal([]byte(` +mode: merge +providers: + " OpenAI ": + api_key: tok +`), &cfg); err != nil { + t.Fatalf("unmarshal config: %v", err) } got := cfg.Provider("openai") @@ -14,3 +23,21 @@ func TestModelsConfigProviderMatchesNormalizedKeys(t *testing.T) { t.Fatalf("expected normalized provider lookup to match, got %#v", got) } } + +func TestModelsConfigUnmarshalRejectsNormalizedKeyCollisions(t *testing.T) { + var cfg ModelsConfig + err := yaml.Unmarshal([]byte(` +mode: merge +providers: + OpenAI: + api_key: tok-1 + " openai ": + api_key: tok-2 +`), &cfg) + if err == nil { + t.Fatal("expected duplicate normalized provider keys to fail") + } + if !strings.Contains(err.Error(), "duplicate provider key") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index a952bdaa..aee69c52 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -4,6 +4,7 @@ import ( "context" "slices" "strings" + "sync" "time" "maunium.net/go/mautrix/bridgev2" @@ -24,6 +25,7 @@ type pendingQueueItem struct { } type pendingQueue struct { + mu sync.Mutex items []pendingQueueItem draining bool lastEnqueuedAt int64 @@ -57,6 +59,7 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe } oc.pendingQueues[roomID] = queue } else { + queue.mu.Lock() queue.mode = settings.Mode if settings.DebounceMs >= 0 { queue.debounceMs = settings.DebounceMs @@ -67,6 +70,7 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe if settings.DropPolicy != "" { queue.dropPolicy = settings.DropPolicy } + queue.mu.Unlock() } return queue } @@ -86,6 +90,8 @@ func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, if queue == nil { return false } + queue.mu.Lock() + defer queue.mu.Unlock() for _, existing := range queue.items { if pendingQueueItemsConflict(item, existing) { @@ -151,6 +157,8 @@ func (oc *AIClient) popQueueItems(roomID id.RoomID, count int) []pendingQueueIte if queue == nil || len(queue.items) == 0 || count <= 0 { return nil } + queue.mu.Lock() + defer queue.mu.Unlock() if count > len(queue.items) { count = len(queue.items) } @@ -170,9 +178,15 @@ func (oc *AIClient) getQueueSnapshot(roomID id.RoomID) *pendingQueue { if queue == nil { return nil } + queue.mu.Lock() + defer queue.mu.Unlock() clone := *queue clone.items = slices.Clone(queue.items) clone.summaryLines = slices.Clone(queue.summaryLines) + if queue.lastItem != nil { + lastItem := *queue.lastItem + clone.lastItem = &lastItem + } return &clone } @@ -186,6 +200,8 @@ func (oc *AIClient) roomHasPendingQueueWork(roomID id.RoomID) bool { if queue == nil { return false } + queue.mu.Lock() + defer queue.mu.Unlock() return queue.draining || len(queue.items) > 0 || queue.droppedCount > 0 } @@ -196,6 +212,8 @@ func (oc *AIClient) consumeQueueSummary(roomID id.RoomID, noun string) string { if queue == nil || queue.droppedCount == 0 { return "" } + queue.mu.Lock() + defer queue.mu.Unlock() summary := buildQueueSummaryPrompt(queue, noun) queue.droppedCount = 0 queue.summaryLines = nil @@ -358,6 +376,11 @@ func (oc *AIClient) markQueueDraining(roomID id.RoomID) bool { if queue == nil || queue.draining { return false } + queue.mu.Lock() + defer queue.mu.Unlock() + if queue.draining { + return false + } queue.draining = true return true } @@ -369,6 +392,8 @@ func (oc *AIClient) clearQueueDraining(roomID id.RoomID) { if queue == nil { return } + queue.mu.Lock() + defer queue.mu.Unlock() queue.draining = false if len(queue.items) == 0 && queue.droppedCount == 0 { delete(oc.pendingQueues, roomID) diff --git a/bridges/ai/prompt_projection_local.go b/bridges/ai/prompt_projection_local.go index 3e03119c..8dee4c39 100644 --- a/bridges/ai/prompt_projection_local.go +++ b/bridges/ai/prompt_projection_local.go @@ -146,7 +146,35 @@ func normalizePromptTurnPartType(partType string) string { } func canonicalPromptToolArguments(raw any) string { + switch typed := raw.(type) { + case nil: + return "{}" + case string: + trimmed := strings.TrimSpace(typed) + if trimmed == "" { + return "{}" + } + var decoded any + if err := json.Unmarshal([]byte(trimmed), &decoded); err == nil { + data, marshalErr := json.Marshal(decoded) + if marshalErr == nil && string(data) != "null" { + return string(data) + } + } + data, err := json.Marshal(typed) + if err == nil && string(data) != "null" { + return string(data) + } + default: + if data, err := json.Marshal(typed); err == nil && string(data) != "null" { + return string(data) + } + } if value := strings.TrimSpace(formatPromptCanonicalValue(raw)); value != "" { + data, err := json.Marshal(value) + if err == nil && string(data) != "null" { + return string(data) + } return value } return "{}" diff --git a/bridges/ai/prompt_projection_local_test.go b/bridges/ai/prompt_projection_local_test.go new file mode 100644 index 00000000..c7676e14 --- /dev/null +++ b/bridges/ai/prompt_projection_local_test.go @@ -0,0 +1,15 @@ +package ai + +import "testing" + +func TestCanonicalPromptToolArgumentsJSONEncodesPlainStrings(t *testing.T) { + if got := canonicalPromptToolArguments("hello"); got != `"hello"` { + t.Fatalf("expected plain string to be JSON-encoded, got %q", got) + } +} + +func TestCanonicalPromptToolArgumentsPreservesJSONStrings(t *testing.T) { + if got := canonicalPromptToolArguments(`{"query":"matrix"}`); got != `{"query":"matrix"}` { + t.Fatalf("expected JSON string to stay canonical JSON, got %q", got) + } +}