diff --git a/bridges/ai/agent_loop_request_builders.go b/bridges/ai/agent_loop_request_builders.go index 3a84b747..3ad189b7 100644 --- a/bridges/ai/agent_loop_request_builders.go +++ b/bridges/ai/agent_loop_request_builders.go @@ -2,6 +2,7 @@ package ai import ( "context" + "strings" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/packages/param" @@ -16,7 +17,6 @@ type agentLoopRequestSettings struct { model string maxTokens int temperature *float64 - systemPrompt string reasoningEffort string } @@ -25,7 +25,6 @@ func (oc *AIClient) buildAgentLoopRequestSettings(meta *PortalMetadata) agentLoo model: oc.effectiveModelForAPI(meta), maxTokens: oc.effectiveMaxTokens(meta), temperature: oc.effectiveTemperature(meta), - systemPrompt: oc.effectivePrompt(meta), reasoningEffort: oc.effectiveReasoningEffort(meta), } } @@ -101,6 +100,7 @@ func (oc *AIClient) buildChatCompletionsAgentLoopParams( func (oc *AIClient) buildResponsesAgentLoopParams( ctx context.Context, meta *PortalMetadata, + systemPrompt string, input responses.ResponseInputParam, allowResolvedBossAgent bool, ) responses.ResponseNewParams { @@ -119,8 +119,8 @@ func (oc *AIClient) buildResponsesAgentLoopParams( if settings.temperature != nil { params.Temperature = openai.Float(*settings.temperature) } - if settings.systemPrompt != "" { - params.Instructions = openai.String(settings.systemPrompt) + if trimmed := strings.TrimSpace(systemPrompt); trimmed != "" { + params.Instructions = openai.String(trimmed) } if effort, ok := reasoningEffortMap[settings.reasoningEffort]; ok { params.Reasoning = shared.ReasoningParam{ diff --git a/bridges/ai/agent_loop_request_builders_test.go b/bridges/ai/agent_loop_request_builders_test.go index 41caf0aa..20e6de3a 100644 --- a/bridges/ai/agent_loop_request_builders_test.go +++ b/bridges/ai/agent_loop_request_builders_test.go @@ -37,7 +37,7 @@ func TestAgentLoopRequestBuildersShareModelAndTokenSettings(t *testing.T) { chatParams := oc.buildChatCompletionsAgentLoopParams(context.Background(), meta, []openai.ChatCompletionMessageParamUnion{ openai.UserMessage("hello"), }) - responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, nil, false) + responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, "system prompt", nil, false) if chatParams.Model != "openai/gpt-5.2" { t.Fatalf("expected chat model openai/gpt-5.2, got %q", chatParams.Model) @@ -96,7 +96,7 @@ func TestAgentLoopRequestBuildersPreserveExplicitZeroTemperature(t *testing.T) { chatParams := oc.buildChatCompletionsAgentLoopParams(context.Background(), meta, []openai.ChatCompletionMessageParamUnion{ openai.UserMessage("hello"), }) - responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, nil, false) + responsesParams := oc.buildResponsesAgentLoopParams(context.Background(), meta, "system prompt", nil, false) if !chatParams.Temperature.Valid() || chatParams.Temperature.Value != 0 { t.Fatalf("expected explicit zero chat temperature, got %#v", chatParams.Temperature) diff --git a/bridges/ai/agent_loop_routing_test.go b/bridges/ai/agent_loop_routing_test.go index 2b92eda0..57bd284a 100644 --- a/bridges/ai/agent_loop_routing_test.go +++ b/bridges/ai/agent_loop_routing_test.go @@ -7,8 +7,6 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func newAgentLoopRoutingTestClient(models ...ModelInfo) *AIClient { @@ -45,13 +43,7 @@ func TestSelectAgentLoopRunFunc_UsesChatCompletionsForUnsupportedResponsesPrompt API: string(ModelAPIResponses), }) - promptContext := PromptContext{ - PromptContext: bridgesdk.UserPromptContext(bridgesdk.PromptBlock{ - Type: bridgesdk.PromptBlockAudio, - AudioB64: "YXVkaW8=", - AudioFormat: "mp3", - }), - } + promptContext := UserPromptContext(PromptBlock{Type: PromptBlockType("unknown")}) responseFn, logLabel := oc.selectAgentLoopRunFunc(resolvedModelMeta("openai/gpt-4.1"), promptContext) if responseFn == nil { diff --git a/bridges/ai/agent_loop_steering_test.go b/bridges/ai/agent_loop_steering_test.go index 1a5330b5..e2036588 100644 --- a/bridges/ai/agent_loop_steering_test.go +++ b/bridges/ai/agent_loop_steering_test.go @@ -9,6 +9,29 @@ import ( airuntime "github.com/beeper/agentremote/pkg/runtime" ) +func getFollowUpMessagesForTest(oc *AIClient, roomID id.RoomID) []PromptMessage { + if oc == nil || roomID == "" { + return nil + } + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil { + return nil + } + behavior := airuntime.ResolveQueueBehavior(snapshot.mode) + if !behavior.Followup { + return nil + } + candidate, _ := oc.takePendingQueueDispatchCandidate(roomID, true) + if candidate == nil || len(candidate.items) == 0 { + return nil + } + _, prompt, ok := preparePendingQueueDispatchCandidate(candidate) + if !ok { + return nil + } + return buildSteeringPromptMessages([]string{prompt}) +} + func TestGetSteeringMessages_FiltersAndDrainsQueue(t *testing.T) { roomID := id.RoomID("!room:example.com") oc := &AIClient{ @@ -53,15 +76,15 @@ func TestGetSteeringMessages_FiltersAndDrainsQueue(t *testing.T) { } func TestBuildSteeringUserMessages(t *testing.T) { - got := buildSteeringUserMessages([]string{"first", " ", "second"}) + got := buildSteeringPromptMessages([]string{"first", " ", "second"}) if len(got) != 2 { - t.Fatalf("expected 2 steering user messages, got %d", len(got)) + t.Fatalf("expected 2 steering prompt messages, got %d", len(got)) } - if got[0].OfUser == nil || got[0].OfUser.Content.OfString.Value != "first" { - t.Fatalf("unexpected first steering user message: %#v", got[0]) + if got[0].Role != PromptRoleUser || got[0].Text() != "first" { + t.Fatalf("unexpected first steering prompt message: %#v", got[0]) } - if got[1].OfUser == nil || got[1].OfUser.Content.OfString.Value != "second" { - t.Fatalf("unexpected second steering user message: %#v", got[1]) + if got[1].Role != PromptRoleUser || got[1].Text() != "second" { + t.Fatalf("unexpected second steering prompt message: %#v", got[1]) } } @@ -78,8 +101,8 @@ func TestGetFollowUpMessages_ConsumesSingleQueuedTextMessage(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil || messages[0].OfUser.Content.OfString.Value != "follow up" { + messages := getFollowUpMessagesForTest(oc, roomID) + if len(messages) != 1 || messages[0].Role != PromptRoleUser || messages[0].Text() != "follow up" { t.Fatalf("unexpected follow-up messages: %#v", messages) } if snapshot := oc.getQueueSnapshot(roomID); snapshot != nil { @@ -101,12 +124,12 @@ func TestGetFollowUpMessages_CollectsQueuedTextMessages(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil { + messages := getFollowUpMessagesForTest(oc, roomID) + if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one combined follow-up message, got %#v", messages) } - if messages[0].OfUser.Content.OfString.Value != "[Queued messages while agent was busy]\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { - t.Fatalf("unexpected combined follow-up prompt: %q", messages[0].OfUser.Content.OfString.Value) + if messages[0].Text() != "[Queued messages while agent was busy]\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { + t.Fatalf("unexpected combined follow-up prompt: %q", messages[0].Text()) } } @@ -127,15 +150,15 @@ func TestGetFollowUpMessages_CollectSummaryIsConsumed(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil { + messages := getFollowUpMessagesForTest(oc, roomID) + if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one combined follow-up message, got %#v", messages) } - if messages[0].OfUser.Content.OfString.Value != "[Queued messages while agent was busy]\n\n[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { - t.Fatalf("unexpected combined follow-up prompt with summary: %q", messages[0].OfUser.Content.OfString.Value) + if messages[0].Text() != "[Queued messages while agent was busy]\n\n[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two\n\n---\nQueued #1\nfirst\n\n---\nQueued #2\nsecond" { + t.Fatalf("unexpected combined follow-up prompt with summary: %q", messages[0].Text()) } - if again := oc.getFollowUpMessages(roomID); len(again) != 0 { + if again := getFollowUpMessagesForTest(oc, roomID); len(again) != 0 { t.Fatalf("expected collect summary to be consumed, got %#v", again) } if snapshot := oc.getQueueSnapshot(roomID); snapshot != nil { @@ -159,12 +182,12 @@ func TestGetFollowUpMessages_UsesSyntheticSummaryPrompt(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) - if len(messages) != 1 || messages[0].OfUser == nil { + messages := getFollowUpMessagesForTest(oc, roomID) + if len(messages) != 1 || messages[0].Role != PromptRoleUser { t.Fatalf("expected one synthetic follow-up message, got %#v", messages) } - if messages[0].OfUser.Content.OfString.Value != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { - t.Fatalf("unexpected synthetic follow-up prompt: %q", messages[0].OfUser.Content.OfString.Value) + if messages[0].Text() != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { + t.Fatalf("unexpected synthetic follow-up prompt: %q", messages[0].Text()) } } @@ -184,23 +207,23 @@ func TestGetFollowUpMessages_SyntheticSummaryIsConsumedBeforeLatestMessage(t *te }, } - first := oc.getFollowUpMessages(roomID) - if len(first) != 1 || first[0].OfUser == nil { + first := getFollowUpMessagesForTest(oc, roomID) + if len(first) != 1 || first[0].Role != PromptRoleUser { t.Fatalf("expected one synthetic follow-up message, got %#v", first) } - if first[0].OfUser.Content.OfString.Value != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { - t.Fatalf("unexpected first synthetic follow-up prompt: %q", first[0].OfUser.Content.OfString.Value) + if first[0].Text() != "[Queue overflow] Dropped 2 messages due to cap.\nSummary:\n- older one\n- older two" { + t.Fatalf("unexpected first synthetic follow-up prompt: %q", first[0].Text()) } - second := oc.getFollowUpMessages(roomID) - if len(second) != 1 || second[0].OfUser == nil { + second := getFollowUpMessagesForTest(oc, roomID) + if len(second) != 1 || second[0].Role != PromptRoleUser { t.Fatalf("expected queued latest message after summary, got %#v", second) } - if second[0].OfUser.Content.OfString.Value != "latest" { - t.Fatalf("expected latest queued message after consuming summary, got %q", second[0].OfUser.Content.OfString.Value) + if second[0].Text() != "latest" { + t.Fatalf("expected latest queued message after consuming summary, got %q", second[0].Text()) } - if third := oc.getFollowUpMessages(roomID); len(third) != 0 { + if third := getFollowUpMessagesForTest(oc, roomID); len(third) != 0 { t.Fatalf("expected queue to be drained after latest message, got %#v", third) } } @@ -218,7 +241,7 @@ func TestGetFollowUpMessages_LeavesNonTextQueueItemsForBacklogProcessing(t *test }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 0 { t.Fatalf("expected non-text follow-up to stay queued, got %#v", messages) } @@ -240,7 +263,7 @@ func TestGetFollowUpMessages_LeavesNonFollowupQueueUntouched(t *testing.T) { }, } - messages := oc.getFollowUpMessages(roomID) + messages := getFollowUpMessagesForTest(oc, roomID) if len(messages) != 0 { t.Fatalf("expected no follow-up messages for non-followup mode, got %#v", messages) } @@ -251,30 +274,54 @@ func TestGetFollowUpMessages_LeavesNonFollowupQueueUntouched(t *testing.T) { func TestBuildContinuationParams_UsesPendingSteeringPromptsBeforeDrainingQueue(t *testing.T) { roomID := id.RoomID("!room:example.com") - oc := &AIClient{ - connector: &OpenAIConnector{}, - activeRoomRuns: map[id.RoomID]*roomRunState{ - roomID: { - steerQueue: []pendingQueueItem{ - {pending: pendingMessage{Type: pendingTypeText, MessageBody: "queue steer"}}, + newClient := func() *AIClient { + return &AIClient{ + connector: &OpenAIConnector{}, + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + steerQueue: []pendingQueueItem{ + {pending: pendingMessage{Type: pendingTypeText, MessageBody: "queue steer"}}, + }, }, }, - }, + } } - state := &streamingState{roomID: roomID} - state.addPendingSteeringPrompts([]string{"pending steer"}) - params := oc.buildContinuationParams(context.Background(), state, nil, nil, nil) - if len(params.Input.OfInputItemList) == 0 { - t.Fatal("expected continuation input to include stored steering prompt") - } - if pending := state.consumePendingSteeringPrompts(); len(pending) != 0 { - t.Fatalf("expected pending steering prompts to be consumed, got %#v", pending) - } - if len(state.baseInput) == 0 { - t.Fatal("expected steering input to persist in base input even when history starts empty") - } - if snapshot := oc.getRoomRun(roomID); snapshot == nil || len(snapshot.steerQueue) != 1 { - t.Fatalf("expected queued steering item to remain available, got %#v", snapshot) - } + t.Run("non-nil prompt", func(t *testing.T) { + oc := newClient() + state := &streamingState{roomID: roomID} + state.addPendingSteeringPrompts([]string{"pending steer"}) + prompt := PromptContext{} + + params := oc.buildContinuationParams(context.Background(), &prompt, state, nil, nil, nil) + if len(params.Input.OfInputItemList) == 0 { + t.Fatal("expected continuation input to include stored steering prompt") + } + if pending := state.consumePendingSteeringPrompts(); len(pending) != 0 { + t.Fatalf("expected pending steering prompts to be consumed, got %#v", pending) + } + if len(prompt.Messages) == 0 { + t.Fatal("expected steering input to persist in canonical prompt even when history starts empty") + } + if snapshot := oc.getRoomRun(roomID); snapshot == nil || len(snapshot.steerQueue) != 1 { + t.Fatalf("expected queued steering item to remain available, got %#v", snapshot) + } + }) + + t.Run("nil prompt", func(t *testing.T) { + oc := newClient() + state := &streamingState{roomID: roomID} + state.addPendingSteeringPrompts([]string{"pending steer"}) + + params := oc.buildContinuationParams(context.Background(), nil, state, nil, nil, nil) + if len(params.Input.OfInputItemList) == 0 { + t.Fatal("expected continuation input to include stored steering prompt") + } + if pending := state.consumePendingSteeringPrompts(); len(pending) != 0 { + t.Fatalf("expected pending steering prompts to be consumed, got %#v", pending) + } + if snapshot := oc.getRoomRun(roomID); snapshot == nil || len(snapshot.steerQueue) != 1 { + t.Fatalf("expected queued steering item to remain available, got %#v", snapshot) + } + }) } diff --git a/bridges/ai/agent_loop_test.go b/bridges/ai/agent_loop_test.go index 730b8fdd..cd1e6412 100644 --- a/bridges/ai/agent_loop_test.go +++ b/bridges/ai/agent_loop_test.go @@ -5,16 +5,13 @@ import ( "errors" "testing" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/event" ) type fakeAgentLoopProvider struct { track bool results []fakeAgentLoopResult - followUps map[int][]openai.ChatCompletionMessageParamUnion finalizeCalls int - continueCalls int roundsObserved []int } @@ -41,19 +38,6 @@ func (f *fakeAgentLoopProvider) FinalizeAgentLoop(context.Context) { f.finalizeCalls++ } -func (f *fakeAgentLoopProvider) GetFollowUpMessages(_ context.Context) []openai.ChatCompletionMessageParamUnion { - if len(f.roundsObserved) == 0 { - return nil - } - return f.followUps[f.roundsObserved[len(f.roundsObserved)-1]] -} - -func (f *fakeAgentLoopProvider) ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) { - if len(messages) > 0 { - f.continueCalls++ - } -} - func TestExecuteAgentLoopRoundsFinalizesOnTerminalTurn(t *testing.T) { provider := &fakeAgentLoopProvider{ results: []fakeAgentLoopResult{ @@ -126,14 +110,10 @@ func TestExecuteAgentLoopRoundsStopsOnContextLengthWithFinalize(t *testing.T) { } } -func TestExecuteAgentLoopRoundsContinuesForFollowUpMessages(t *testing.T) { +func TestExecuteAgentLoopRoundsDoesNotInlineFollowUpMessages(t *testing.T) { provider := &fakeAgentLoopProvider{ results: []fakeAgentLoopResult{ {continueLoop: false}, - {continueLoop: false}, - }, - followUps: map[int][]openai.ChatCompletionMessageParamUnion{ - 0: {openai.UserMessage("follow up")}, }, } @@ -147,13 +127,10 @@ func TestExecuteAgentLoopRoundsContinuesForFollowUpMessages(t *testing.T) { if err != nil { t.Fatalf("expected no error, got %v", err) } - if provider.continueCalls != 1 { - t.Fatalf("expected one follow-up continuation, got %d", provider.continueCalls) - } if provider.finalizeCalls != 1 { t.Fatalf("expected finalize once, got %d", provider.finalizeCalls) } - if len(provider.roundsObserved) != 2 || provider.roundsObserved[0] != 0 || provider.roundsObserved[1] != 1 { + if len(provider.roundsObserved) != 1 || provider.roundsObserved[0] != 0 { t.Fatalf("unexpected rounds observed: %#v", provider.roundsObserved) } } diff --git a/bridges/ai/beeper_models_generated.go b/bridges/ai/beeper_models_generated.go index 10f9f707..ecb15657 100644 --- a/bridges/ai/beeper_models_generated.go +++ b/bridges/ai/beeper_models_generated.go @@ -1,5 +1,5 @@ // Code generated by generate-models. DO NOT EDIT. -// Generated at: 2026-03-08T11:58:59Z +// Generated at: 2026-03-30T00:58:12Z package ai @@ -27,40 +27,6 @@ var ModelManifest = struct { MaxOutputTokens: 64000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "anthropic/claude-opus-4.1": { - ID: "anthropic/claude-opus-4.1", - Name: "Claude 4.1 Opus", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 32000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "anthropic/claude-opus-4.5": { - ID: "anthropic/claude-opus-4.5", - Name: "Claude Opus 4.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 64000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, "anthropic/claude-opus-4.6": { ID: "anthropic/claude-opus-4.6", Name: "Claude Opus 4.6", @@ -78,40 +44,6 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "anthropic/claude-sonnet-4": { - ID: "anthropic/claude-sonnet-4", - Name: "Claude 4 Sonnet", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 64000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "anthropic/claude-sonnet-4.5": { - ID: "anthropic/claude-sonnet-4.5", - Name: "Claude Sonnet 4.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1000000, - MaxOutputTokens: 64000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, "anthropic/claude-sonnet-4.6": { ID: "anthropic/claude-sonnet-4.6", Name: "Claude Sonnet 4.6", @@ -129,57 +61,6 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "deepseek/deepseek-chat-v3-0324": { - ID: "deepseek/deepseek-chat-v3-0324", - Name: "DeepSeek v3 (0324)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 163840, - MaxOutputTokens: 163840, - AvailableTools: []string{ToolFunctionCalling}, - }, - "deepseek/deepseek-chat-v3.1": { - ID: "deepseek/deepseek-chat-v3.1", - Name: "DeepSeek v3.1", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 32768, - MaxOutputTokens: 7168, - AvailableTools: []string{ToolFunctionCalling}, - }, - "deepseek/deepseek-r1": { - ID: "deepseek/deepseek-r1", - Name: "DeepSeek R1 (Original)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 64000, - MaxOutputTokens: 16000, - AvailableTools: []string{ToolFunctionCalling}, - }, "deepseek/deepseek-r1-0528": { ID: "deepseek/deepseek-r1-0528", Name: "DeepSeek R1 (0528)", @@ -197,40 +78,6 @@ var ModelManifest = struct { MaxOutputTokens: 65536, AvailableTools: []string{ToolFunctionCalling}, }, - "deepseek/deepseek-r1-distill-qwen-32b": { - ID: "deepseek/deepseek-r1-distill-qwen-32b", - Name: "DeepSeek R1 (Qwen Distilled)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: false, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 32768, - MaxOutputTokens: 32768, - AvailableTools: []string{}, - }, - "deepseek/deepseek-v3.1-terminus": { - ID: "deepseek/deepseek-v3.1-terminus", - Name: "DeepSeek v3.1 Terminus", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 163840, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, "deepseek/deepseek-v3.2": { ID: "deepseek/deepseek-v3.2", Name: "DeepSeek v3.2", @@ -245,77 +92,9 @@ var ModelManifest = struct { SupportsVideo: false, SupportsPDF: false, ContextWindow: 163840, - MaxOutputTokens: 65536, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-2.0-flash-001": { - ID: "google/gemini-2.0-flash-001", - Name: "Gemini 2.0 Flash", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 8192, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-2.0-flash-lite-001": { - ID: "google/gemini-2.0-flash-lite-001", - Name: "Gemini 2.0 Flash Lite", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 8192, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-2.5-flash": { - ID: "google/gemini-2.5-flash", - Name: "Gemini 2.5 Flash", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 65535, + MaxOutputTokens: 0, AvailableTools: []string{ToolFunctionCalling}, }, - "google/gemini-2.5-flash-image": { - ID: "google/gemini-2.5-flash-image", - Name: "Nano Banana", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: false, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: true, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 32768, - MaxOutputTokens: 32768, - AvailableTools: []string{}, - }, "google/gemini-2.5-flash-lite": { ID: "google/gemini-2.5-flash-lite", Name: "Gemini 2.5 Flash Lite", @@ -367,73 +146,22 @@ var ModelManifest = struct { MaxOutputTokens: 65536, AvailableTools: []string{ToolFunctionCalling}, }, - "google/gemini-3-pro-image-preview": { - ID: "google/gemini-3-pro-image-preview", - Name: "Nano Banana Pro", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: false, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: true, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 65536, - MaxOutputTokens: 32768, - AvailableTools: []string{}, - }, - "google/gemini-3.1-flash-lite-preview": { - ID: "google/gemini-3.1-flash-lite-preview", - Name: "Gemini 3.1 Flash Lite", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 65536, - AvailableTools: []string{ToolFunctionCalling}, - }, - "google/gemini-3.1-pro-preview": { - ID: "google/gemini-3.1-pro-preview", - Name: "Gemini 3.1 Pro", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: true, - SupportsVideo: true, - SupportsPDF: true, - ContextWindow: 1048576, - MaxOutputTokens: 65536, - AvailableTools: []string{ToolFunctionCalling}, - }, - "meta-llama/llama-3.3-70b-instruct": { - ID: "meta-llama/llama-3.3-70b-instruct", - Name: "Llama 3.3 70B", + "google/gemma-2-27b-it": { + ID: "google/gemma-2-27b-it", + Name: "Gemma 2 27B", Provider: "openrouter", API: "openai-completions", SupportsVision: false, - SupportsToolCalling: true, + SupportsToolCalling: false, SupportsReasoning: false, SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolFunctionCalling}, + ContextWindow: 8192, + MaxOutputTokens: 2048, + AvailableTools: []string{}, }, "meta-llama/llama-4-maverick": { ID: "meta-llama/llama-4-maverick", @@ -452,283 +180,79 @@ var ModelManifest = struct { MaxOutputTokens: 16384, AvailableTools: []string{ToolFunctionCalling}, }, - "meta-llama/llama-4-scout": { - ID: "meta-llama/llama-4-scout", - Name: "Llama 4 Scout", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 327680, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolFunctionCalling}, - }, - "minimax/minimax-m2": { - ID: "minimax/minimax-m2", - Name: "MiniMax M2", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 196608, - MaxOutputTokens: 196608, - AvailableTools: []string{ToolFunctionCalling}, - }, - "minimax/minimax-m2.1": { - ID: "minimax/minimax-m2.1", - Name: "MiniMax M2.1", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 196608, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "minimax/minimax-m2.5": { - ID: "minimax/minimax-m2.5", - Name: "MiniMax M2.5", + "minimax/minimax-m2.7": { + ID: "minimax/minimax-m2.7", + Name: "MiniMax M2.7", Provider: "openrouter", API: "openai-completions", SupportsVision: false, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 196608, - MaxOutputTokens: 196608, - AvailableTools: []string{ToolFunctionCalling}, - }, - "moonshotai/kimi-k2": { - ID: "moonshotai/kimi-k2", - Name: "Kimi K2 (0711)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131000, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "moonshotai/kimi-k2-0905": { - ID: "moonshotai/kimi-k2-0905", - Name: "Kimi K2 (0905)", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "moonshotai/kimi-k2.5": { - ID: "moonshotai/kimi-k2.5", - Name: "Kimi K2.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 262144, - MaxOutputTokens: 65535, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/gpt-4.1": { - ID: "openai/gpt-4.1", - Name: "GPT-4.1", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1047576, - MaxOutputTokens: 32768, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-4.1-mini": { - ID: "openai/gpt-4.1-mini", - Name: "GPT-4.1 Mini", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1047576, - MaxOutputTokens: 32768, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-4.1-nano": { - ID: "openai/gpt-4.1-nano", - Name: "GPT-4.1 Nano", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 1047576, - MaxOutputTokens: 32768, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-4o-mini": { - ID: "openai/gpt-4o-mini", - Name: "GPT-4o-mini", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 128000, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-5": { - ID: "openai/gpt-5", - Name: "GPT-5", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "openai/gpt-5-image": { - ID: "openai/gpt-5-image", - Name: "GPT ImageGen 1.5", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: true, + SupportsWebSearch: false, + SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 204800, + MaxOutputTokens: 131072, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5-image-mini": { - ID: "openai/gpt-5-image-mini", - Name: "GPT ImageGen", + "mistralai/devstral-2512": { + ID: "mistralai/devstral-2512", + Name: "Devstral 2", Provider: "openrouter", - API: "responses", - SupportsVision: true, + API: "openai-completions", + SupportsVision: false, SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: true, + SupportsReasoning: false, + SupportsWebSearch: false, + SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 262144, + MaxOutputTokens: 0, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5-mini": { - ID: "openai/gpt-5-mini", - Name: "GPT-5 mini", + "mistralai/mistral-small-2603": { + ID: "mistralai/mistral-small-2603", + Name: "Mistral Small 4", Provider: "openrouter", - API: "responses", + API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 262144, + MaxOutputTokens: 0, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5-nano": { - ID: "openai/gpt-5-nano", - Name: "GPT-5 nano", + "moonshotai/kimi-k2.5": { + ID: "moonshotai/kimi-k2.5", + Name: "Kimi K2.5", Provider: "openrouter", - API: "responses", + API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 400000, - MaxOutputTokens: 128000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + SupportsPDF: false, + ContextWindow: 262144, + MaxOutputTokens: 65535, + AvailableTools: []string{ToolFunctionCalling}, }, - "openai/gpt-5.1": { - ID: "openai/gpt-5.1", - Name: "GPT-5.1", + "openai/gpt-5-mini": { + ID: "openai/gpt-5-mini", + Name: "GPT-5 mini", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -745,7 +269,7 @@ var ModelManifest = struct { ID: "openai/gpt-5.2", Name: "GPT-5.2", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -758,11 +282,11 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/gpt-5.2-pro": { - ID: "openai/gpt-5.2-pro", - Name: "GPT-5.2 Pro", + "openai/gpt-5.3-codex": { + ID: "openai/gpt-5.3-codex", + Name: "GPT-5.3 Codex", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -775,28 +299,11 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/gpt-5.3-chat": { - ID: "openai/gpt-5.3-chat", - Name: "GPT-5.3 Instant", - Provider: "openrouter", - API: "responses", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 128000, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, "openai/gpt-5.4": { ID: "openai/gpt-5.4", Name: "GPT-5.4", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -809,45 +316,11 @@ var ModelManifest = struct { MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/gpt-oss-120b": { - ID: "openai/gpt-oss-120b", - Name: "GPT OSS 120B", - Provider: "openrouter", - API: "responses", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/gpt-oss-20b": { - ID: "openai/gpt-oss-20b", - Name: "GPT OSS 20B", - Provider: "openrouter", - API: "responses", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/o3": { - ID: "openai/o3", - Name: "o3", + "openai/gpt-5.4-mini": { + ID: "openai/gpt-5.4-mini", + Name: "GPT-5.4 Mini", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -856,32 +329,15 @@ var ModelManifest = struct { SupportsAudio: false, SupportsVideo: false, SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 100000, + ContextWindow: 400000, + MaxOutputTokens: 128000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "openai/o3-mini": { - ID: "openai/o3-mini", - Name: "o3-mini", - Provider: "openrouter", - API: "responses", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: true, - ContextWindow: 200000, - MaxOutputTokens: 100000, - AvailableTools: []string{ToolFunctionCalling}, - }, - "openai/o3-pro": { - ID: "openai/o3-pro", - Name: "o3 Pro", + "openai/o3": { + ID: "openai/o3", + Name: "o3", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -898,7 +354,7 @@ var ModelManifest = struct { ID: "openai/o4-mini", Name: "o4-mini", Provider: "openrouter", - API: "responses", + API: "openai-responses", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, @@ -928,43 +384,9 @@ var ModelManifest = struct { MaxOutputTokens: 0, AvailableTools: []string{}, }, - "qwen/qwen3-235b-a22b": { - ID: "qwen/qwen3-235b-a22b", - Name: "Qwen 3 235B", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 8192, - AvailableTools: []string{ToolFunctionCalling}, - }, - "qwen/qwen3-32b": { - ID: "qwen/qwen3-32b", - Name: "Qwen 3 32B", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 40960, - MaxOutputTokens: 40960, - AvailableTools: []string{ToolFunctionCalling}, - }, - "qwen/qwen3-coder": { - ID: "qwen/qwen3-coder", - Name: "Qwen 3 Coder", + "qwen/qwen3-coder-next": { + ID: "qwen/qwen3-coder-next", + Name: "Qwen 3 Coder Next", Provider: "openrouter", API: "openai-completions", SupportsVision: false, @@ -976,76 +398,42 @@ var ModelManifest = struct { SupportsVideo: false, SupportsPDF: false, ContextWindow: 262144, - MaxOutputTokens: 0, + MaxOutputTokens: 65536, AvailableTools: []string{ToolFunctionCalling}, }, - "x-ai/grok-3": { - ID: "x-ai/grok-3", - Name: "Grok 3", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: false, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "x-ai/grok-3-mini": { - ID: "x-ai/grok-3-mini", - Name: "Grok 3 Mini", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: true, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 0, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, - }, - "x-ai/grok-4": { - ID: "x-ai/grok-4", - Name: "Grok 4", + "qwen/qwen3.5-flash-02-23": { + ID: "qwen/qwen3.5-flash-02-23", + Name: "Qwen 3.5 Flash", Provider: "openrouter", API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, - SupportsVideo: false, + SupportsVideo: true, SupportsPDF: false, - ContextWindow: 256000, - MaxOutputTokens: 0, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + ContextWindow: 1000000, + MaxOutputTokens: 65536, + AvailableTools: []string{ToolFunctionCalling}, }, - "x-ai/grok-4-fast": { - ID: "x-ai/grok-4-fast", - Name: "Grok 4 Fast", + "qwen/qwen3.5-plus-02-15": { + ID: "qwen/qwen3.5-plus-02-15", + Name: "Qwen 3.5 Plus", Provider: "openrouter", API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: true, + SupportsWebSearch: false, SupportsImageGen: false, SupportsAudio: false, - SupportsVideo: false, + SupportsVideo: true, SupportsPDF: false, - ContextWindow: 2000000, - MaxOutputTokens: 30000, - AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, + ContextWindow: 1000000, + MaxOutputTokens: 65536, + AvailableTools: []string{ToolFunctionCalling}, }, "x-ai/grok-4.1-fast": { ID: "x-ai/grok-4.1-fast", @@ -1064,111 +452,43 @@ var ModelManifest = struct { MaxOutputTokens: 30000, AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "z-ai/glm-4.5": { - ID: "z-ai/glm-4.5", - Name: "GLM 4.5", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 98304, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.5-air": { - ID: "z-ai/glm-4.5-air", - Name: "GLM 4.5 Air", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 98304, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.5v": { - ID: "z-ai/glm-4.5v", - Name: "GLM 4.5V", + "x-ai/grok-4.20-beta": { + ID: "x-ai/grok-4.20-beta", + Name: "Grok 4.20 Beta", Provider: "openrouter", API: "openai-completions", SupportsVision: true, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: false, - SupportsPDF: false, - ContextWindow: 65536, - MaxOutputTokens: 16384, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.6": { - ID: "z-ai/glm-4.6", - Name: "GLM 4.6", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: false, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, + SupportsWebSearch: true, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, SupportsPDF: false, - ContextWindow: 204800, - MaxOutputTokens: 204800, - AvailableTools: []string{ToolFunctionCalling}, - }, - "z-ai/glm-4.6v": { - ID: "z-ai/glm-4.6v", - Name: "GLM 4.6V", - Provider: "openrouter", - API: "openai-completions", - SupportsVision: true, - SupportsToolCalling: true, - SupportsReasoning: true, - SupportsWebSearch: false, - SupportsImageGen: false, - SupportsAudio: false, - SupportsVideo: true, - SupportsPDF: false, - ContextWindow: 131072, - MaxOutputTokens: 131072, - AvailableTools: []string{ToolFunctionCalling}, + ContextWindow: 2000000, + MaxOutputTokens: 0, + AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "z-ai/glm-4.7": { - ID: "z-ai/glm-4.7", - Name: "GLM 4.7", + "x-ai/grok-code-fast-1": { + ID: "x-ai/grok-code-fast-1", + Name: "Grok Code Fast 1", Provider: "openrouter", API: "openai-completions", SupportsVision: false, SupportsToolCalling: true, SupportsReasoning: true, - SupportsWebSearch: false, + SupportsWebSearch: true, SupportsImageGen: false, SupportsAudio: false, SupportsVideo: false, SupportsPDF: false, - ContextWindow: 202752, - MaxOutputTokens: 0, - AvailableTools: []string{ToolFunctionCalling}, + ContextWindow: 256000, + MaxOutputTokens: 10000, + AvailableTools: []string{ToolWebSearch, ToolFunctionCalling}, }, - "z-ai/glm-5": { - ID: "z-ai/glm-5", - Name: "GLM 5", + "z-ai/glm-5-turbo": { + ID: "z-ai/glm-5-turbo", + Name: "GLM 5 Turbo", Provider: "openrouter", API: "openai-completions", SupportsVision: false, @@ -1180,14 +500,14 @@ var ModelManifest = struct { SupportsVideo: false, SupportsPDF: false, ContextWindow: 202752, - MaxOutputTokens: 0, + MaxOutputTokens: 131072, AvailableTools: []string{ToolFunctionCalling}, }, }, Aliases: map[string]string{ "beeper/default": "anthropic/claude-opus-4.6", - "beeper/fast": "openai/gpt-5-mini", - "beeper/reasoning": "openai/gpt-5.2", - "beeper/smart": "openai/gpt-5.2", + "beeper/fast": "openai/gpt-5.4-mini", + "beeper/reasoning": "openai/o3", + "beeper/smart": "openai/gpt-5.4", }, } diff --git a/bridges/ai/beeper_models_manifest_test.go b/bridges/ai/beeper_models_manifest_test.go index 47db642c..3abdeb6e 100644 --- a/bridges/ai/beeper_models_manifest_test.go +++ b/bridges/ai/beeper_models_manifest_test.go @@ -4,75 +4,35 @@ import "testing" func TestModelManifestMatchesOpenRouterAllowlist(t *testing.T) { expected := map[string]struct{}{ - "google/gemini-3.1-flash-lite-preview": {}, - "openai/gpt-5.3-chat": {}, - "openai/gpt-5.4": {}, - "qwen/qwen2.5-vl-32b-instruct": {}, - "qwen/qwen3-32b": {}, - "qwen/qwen3-235b-a22b": {}, - "qwen/qwen3-coder": {}, - "anthropic/claude-sonnet-4": {}, - "anthropic/claude-sonnet-4.5": {}, - "anthropic/claude-sonnet-4.6": {}, - "anthropic/claude-opus-4.1": {}, - "anthropic/claude-haiku-4.5": {}, - "anthropic/claude-opus-4.5": {}, - "anthropic/claude-opus-4.6": {}, - "deepseek/deepseek-chat-v3-0324": {}, - "deepseek/deepseek-chat-v3.1": {}, - "deepseek/deepseek-v3.1-terminus": {}, - "deepseek/deepseek-v3.2": {}, - "deepseek/deepseek-r1": {}, - "deepseek/deepseek-r1-0528": {}, - "deepseek/deepseek-r1-distill-qwen-32b": {}, - "google/gemini-2.0-flash-001": {}, - "google/gemini-2.5-flash": {}, - "google/gemini-2.5-flash-lite": {}, - "google/gemini-2.5-flash-image": {}, - "google/gemini-2.0-flash-lite-001": {}, - "google/gemini-2.5-pro": {}, - "google/gemini-3.1-pro-preview": {}, - "google/gemini-3-pro-image-preview": {}, - "google/gemini-3-flash-preview": {}, - "meta-llama/llama-3.3-70b-instruct": {}, - "meta-llama/llama-4-scout": {}, - "meta-llama/llama-4-maverick": {}, - "minimax/minimax-m2": {}, - "minimax/minimax-m2.1": {}, - "minimax/minimax-m2.5": {}, - "moonshotai/kimi-k2": {}, - "moonshotai/kimi-k2-0905": {}, - "moonshotai/kimi-k2.5": {}, - "openai/gpt-oss-20b": {}, - "openai/gpt-oss-120b": {}, - "openai/gpt-4o-mini": {}, - "openai/gpt-4.1": {}, - "openai/gpt-4.1-mini": {}, - "openai/gpt-4.1-nano": {}, - "openai/gpt-5": {}, - "openai/gpt-5-mini": {}, - "openai/gpt-5-nano": {}, - "openai/gpt-5.1": {}, - "openai/gpt-5.2": {}, - "openai/gpt-5.2-pro": {}, - "openai/o3-mini": {}, - "openai/o4-mini": {}, - "openai/o3": {}, - "openai/o3-pro": {}, - "openai/gpt-5-image-mini": {}, - "openai/gpt-5-image": {}, - "z-ai/glm-4.5": {}, - "z-ai/glm-4.5v": {}, - "z-ai/glm-4.5-air": {}, - "z-ai/glm-4.6": {}, - "z-ai/glm-4.6v": {}, - "z-ai/glm-4.7": {}, - "z-ai/glm-5": {}, - "x-ai/grok-4": {}, - "x-ai/grok-3": {}, - "x-ai/grok-3-mini": {}, - "x-ai/grok-4-fast": {}, - "x-ai/grok-4.1-fast": {}, + "anthropic/claude-haiku-4.5": {}, + "anthropic/claude-opus-4.6": {}, + "anthropic/claude-sonnet-4.6": {}, + "deepseek/deepseek-r1-0528": {}, + "deepseek/deepseek-v3.2": {}, + "google/gemini-2.5-flash-lite": {}, + "google/gemini-2.5-pro": {}, + "google/gemini-3-flash-preview": {}, + "google/gemma-2-27b-it": {}, + "meta-llama/llama-4-maverick": {}, + "minimax/minimax-m2.7": {}, + "mistralai/devstral-2512": {}, + "mistralai/mistral-small-2603": {}, + "moonshotai/kimi-k2.5": {}, + "openai/gpt-5-mini": {}, + "openai/gpt-5.2": {}, + "openai/gpt-5.3-codex": {}, + "openai/gpt-5.4": {}, + "openai/gpt-5.4-mini": {}, + "openai/o3": {}, + "openai/o4-mini": {}, + "qwen/qwen2.5-vl-32b-instruct": {}, + "qwen/qwen3-coder-next": {}, + "qwen/qwen3.5-flash-02-23": {}, + "qwen/qwen3.5-plus-02-15": {}, + "x-ai/grok-4.1-fast": {}, + "x-ai/grok-4.20-beta": {}, + "x-ai/grok-code-fast-1": {}, + "z-ai/glm-5-turbo": {}, } if len(ModelManifest.Models) != len(expected) { diff --git a/bridges/ai/canonical_prompt_messages.go b/bridges/ai/canonical_prompt_messages.go index db8c578f..4ff88b83 100644 --- a/bridges/ai/canonical_prompt_messages.go +++ b/bridges/ai/canonical_prompt_messages.go @@ -2,13 +2,11 @@ package ai import ( "strings" - - "github.com/beeper/agentremote/sdk" ) func promptMessagesFromMetadata(meta *MessageMetadata) []PromptMessage { if turnData, ok := canonicalTurnData(meta); ok { - return sdk.PromptMessagesFromTurnData(turnData) + return promptMessagesFromTurnData(turnData) } return nil } @@ -43,6 +41,8 @@ func filterPromptBlocksForHistory(blocks []PromptBlock, injectImages bool) []Pro if injectImages { filtered = append(filtered, block) } + case PromptBlockThinking: + continue default: filtered = append(filtered, block) } @@ -50,20 +50,6 @@ func filterPromptBlocksForHistory(blocks []PromptBlock, injectImages bool) []Pro return filtered } -func textPromptMessage(text string) []PromptMessage { - text = strings.TrimSpace(text) - if text == "" { - return nil - } - return []PromptMessage{{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: text, - }}, - }} -} - func promptTail(ctx PromptContext, count int) []PromptMessage { if count <= 0 || len(ctx.Messages) == 0 { return nil @@ -80,7 +66,7 @@ func setCanonicalTurnDataFromPromptMessages(meta *MessageMetadata, messages []Pr if meta == nil || len(messages) == 0 { return } - if turnData, ok := sdk.TurnDataFromUserPromptMessages(messages); ok { + if turnData, ok := turnDataFromUserPromptMessages(messages); ok { meta.CanonicalTurnData = turnData.ToMap() } else { meta.CanonicalTurnData = nil diff --git a/bridges/ai/canonical_user_messages.go b/bridges/ai/canonical_user_messages.go deleted file mode 100644 index 7a84b89a..00000000 --- a/bridges/ai/canonical_user_messages.go +++ /dev/null @@ -1,25 +0,0 @@ -package ai - -import ( - "strings" - - "maunium.net/go/mautrix/bridgev2/database" -) - -func ensureCanonicalUserMessage(msg *database.Message) { - if msg == nil { - return - } - meta, ok := msg.Metadata.(*MessageMetadata) - if !ok || meta == nil || strings.TrimSpace(meta.Role) != "user" { - return - } - if len(meta.CanonicalTurnData) > 0 { - return - } - - body := strings.TrimSpace(meta.Body) - if body != "" { - setCanonicalTurnDataFromPromptMessages(meta, textPromptMessage(body)) - } -} diff --git a/bridges/ai/chat.go b/bridges/ai/chat.go index 46c85a65..caf7884a 100644 --- a/bridges/ai/chat.go +++ b/bridges/ai/chat.go @@ -122,7 +122,7 @@ func (oc *AIClient) canUseImageGeneration() bool { return false } loginMeta := loginMetadata(oc.UserLogin) - if loginMeta == nil || loginMeta.APIKey == "" { + if loginMeta == nil || strings.TrimSpace(oc.connector.resolveProviderAPIKey(loginMeta)) == "" { return false } switch loginMeta.Provider { diff --git a/bridges/ai/chat_login_redirect_test.go b/bridges/ai/chat_login_redirect_test.go index 42ea3941..9e473323 100644 --- a/bridges/ai/chat_login_redirect_test.go +++ b/bridges/ai/chat_login_redirect_test.go @@ -140,16 +140,16 @@ func TestModelRedirectTarget(t *testing.T) { } func TestResolveModelIDFromManifestAcceptsRawModelID(t *testing.T) { - const modelID = "google/gemini-2.0-flash-lite-001" + const modelID = "google/gemini-3-flash-preview" if got := resolveModelIDFromManifest(modelID); got != modelID { t.Fatalf("expected raw model ID %q to resolve, got %q", modelID, got) } } func TestResolveModelIDFromManifestAcceptsEncodedModelIDViaCandidates(t *testing.T) { - const encoded = "google%2Fgemini-2.0-flash-lite-001" + const encoded = "google%2Fgemini-3-flash-preview" candidates := candidateModelLookupIDs(encoded) - const canonical = "google/gemini-2.0-flash-lite-001" + const canonical = "google/gemini-3-flash-preview" if !slices.Contains(candidates, canonical) { t.Fatalf("expected decoded model candidate in %#v", candidates) } @@ -169,8 +169,8 @@ func TestCandidateModelLookupIDsRejectsMalformedEncoding(t *testing.T) { } func TestParseModelFromGhostIDAcceptsEscapedGhostID(t *testing.T) { - const ghostID = "model-google%2Fgemini-2.0-flash-lite-001" - const want = "google/gemini-2.0-flash-lite-001" + const ghostID = "model-google%2Fgemini-3-flash-preview" + const want = "google/gemini-3-flash-preview" if got := parseModelFromGhostID(ghostID); got != want { t.Fatalf("expected ghost ID %q to parse to %q, got %q", ghostID, want, got) } @@ -190,8 +190,8 @@ func TestResolveIdentifierAcceptsCanonicalModelIdentifier(t *testing.T) { Metadata: &UserLoginMetadata{ ModelCache: &ModelCache{ Models: []ModelInfo{{ - ID: "openai/gpt-5", - Name: "GPT-5", + ID: "openai/gpt-5.4", + Name: "GPT-5.4", }}, }, }, @@ -200,11 +200,11 @@ func TestResolveIdentifierAcceptsCanonicalModelIdentifier(t *testing.T) { connector: &OpenAIConnector{}, } - resp, err := oc.ResolveIdentifier(context.Background(), "model:openai/gpt-5", false) + resp, err := oc.ResolveIdentifier(context.Background(), "model:openai/gpt-5.4", false) if err != nil { t.Fatalf("ResolveIdentifier returned error: %v", err) } - if resp == nil || resp.UserID != modelUserID("openai/gpt-5") { + if resp == nil || resp.UserID != modelUserID("openai/gpt-5.4") { t.Fatalf("expected canonical model identifier to resolve to model ghost, got %#v", resp) } } diff --git a/bridges/ai/client.go b/bridges/ai/client.go index c6faf97f..1466d972 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -26,9 +26,9 @@ import ( "github.com/beeper/agentremote" "github.com/beeper/agentremote/pkg/agents" + integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" airuntime "github.com/beeper/agentremote/pkg/runtime" "github.com/beeper/agentremote/pkg/shared/stringutil" - bridgesdk "github.com/beeper/agentremote/sdk" ) var ( @@ -312,11 +312,10 @@ type AIClient struct { // Heartbeat + integrations scheduler *schedulerRuntime - integrationModules map[string]any + integrationModules map[string]integrationruntime.ModuleHooks integrationOrder []string toolRegistry *toolIntegrationRegistry - promptRegistry *promptIntegrationRegistry commandRegistry *commandIntegrationRegistry eventRegistry *eventIntegrationRegistry purgeRegistry *purgeIntegrationRegistry @@ -486,16 +485,14 @@ func initProviderForLogin(key string, meta *UserLoginMetadata, connector *OpenAI } switch meta.Provider { case ProviderOpenRouter: - pdfEngine := connector.Config.Providers.OpenRouter.DefaultPDFEngine - return initOpenRouterProvider(key, connector.resolveOpenRouterBaseURL(), "", pdfEngine, ProviderOpenRouter, log) + return initOpenRouterProvider(key, connector.resolveOpenRouterBaseURL(), "", connector.defaultPDFEngineForInit(), ProviderOpenRouter, log) case ProviderMagicProxy: - baseURL := normalizeProxyBaseURL(meta.BaseURL) + baseURL := normalizeProxyBaseURL(loginCredentialBaseURL(meta)) if baseURL == "" { return nil, errors.New("magic proxy base_url is required") } - pdfEngine := connector.Config.Providers.OpenRouter.DefaultPDFEngine - return initOpenRouterProvider(key, joinProxyPath(baseURL, "/openrouter/v1"), "", pdfEngine, ProviderMagicProxy, log) + return initOpenRouterProvider(key, joinProxyPath(baseURL, "/openrouter/v1"), "", connector.defaultPDFEngineForInit(), ProviderMagicProxy, log) case ProviderOpenAI: openaiURL := connector.resolveOpenAIBaseURL() @@ -510,6 +507,15 @@ func initProviderForLogin(key string, meta *UserLoginMetadata, connector *OpenAI } } +func (oc *OpenAIConnector) defaultPDFEngineForInit() string { + if oc != nil && oc.Config.Agents != nil && oc.Config.Agents.Defaults != nil { + if engine := strings.TrimSpace(oc.Config.Agents.Defaults.PDFEngine); engine != "" { + return engine + } + } + return "mistral-ocr" +} + // initOpenRouterProvider creates an OpenRouter-compatible provider with PDF support. func initOpenRouterProvider(key, url, userID, pdfEngine, providerName string, log zerolog.Logger) (*OpenAIProvider, error) { log.Info(). @@ -604,7 +610,6 @@ func (oc *AIClient) saveUserMessage(ctx context.Context, evt *event.Event, msg * if evt != nil { msg.MXID = evt.ID } - ensureCanonicalUserMessage(msg) if _, err := oc.UserLogin.Bridge.GetGhostByID(ctx, msg.SenderID); err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure user ghost before saving message") } @@ -632,12 +637,14 @@ func (oc *AIClient) dispatchOrQueueCore( shouldSteer := behavior.Steer shouldFollowup := behavior.Followup hasDBMessage := userMessage != nil - queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, oc.roomHasActiveRun(roomID), false) + roomBusy := oc.roomHasActiveRun(roomID) || oc.roomHasPendingQueueWork(roomID) + queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, roomBusy, false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(roomID) oc.clearPendingQueue(roomID) + roomBusy = false } - if oc.acquireRoom(roomID) { + if !roomBusy && oc.acquireRoom(roomID) { oc.stopQueueTyping(roomID) if hasDBMessage { oc.saveUserMessage(ctx, evt, userMessage) @@ -794,9 +801,9 @@ func (oc *AIClient) processPendingQueue(ctx context.Context, roomID id.RoomID) { } switch item.pending.Type { case pendingTypeText: - promptContext, err = oc.buildContextWithLinkContext(promptCtx, item.pending.Portal, metaSnapshot, prompt, item.rawEventContent, eventID) + promptContext, err = oc.buildCurrentTurnWithLinks(promptCtx, item.pending.Portal, metaSnapshot, prompt, item.rawEventContent, eventID) case pendingTypeImage, pendingTypePDF, pendingTypeAudio, pendingTypeVideo: - promptContext, err = oc.buildContextWithMedia(promptCtx, item.pending.Portal, metaSnapshot, item.pending.MessageBody, item.pending.MediaURL, item.pending.MimeType, item.pending.EncryptedFile, item.pending.Type, eventID) + promptContext, err = oc.buildMediaTurnContext(promptCtx, item.pending.Portal, metaSnapshot, item.pending.MessageBody, item.pending.MediaURL, item.pending.MimeType, item.pending.EncryptedFile, item.pending.Type, eventID) case pendingTypeRegenerate: promptContext, err = oc.buildContextForRegenerate(promptCtx, item.pending.Portal, metaSnapshot, item.pending.MessageBody, item.pending.SourceEventID) case pendingTypeEditRegenerate: @@ -1010,6 +1017,7 @@ func updateGhostLastSync(_ context.Context, ghost *bridgev2.Ghost) bool { func (oc *AIClient) GetCapabilities(ctx context.Context, portal *bridgev2.Portal) *event.RoomFeatures { meta := portalMeta(portal) + isModelRoom := meta != nil && meta.ResolvedTarget != nil && meta.ResolvedTarget.Kind == ResolvedTargetModel // Always recompute effective room capabilities from the resolved room target. modelCaps := oc.getRoomCapabilities(ctx, meta) @@ -1029,11 +1037,17 @@ func (oc *AIClient) GetCapabilities(ctx context.Context, portal *bridgev2.Portal if supportsMsgActions { caps.Reply = event.CapLevelFullySupported - caps.Edit = event.CapLevelFullySupported - caps.EditMaxCount = 10 - caps.EditMaxAge = ptr.Ptr(jsontime.S(AIEditMaxAge)) caps.Reaction = event.CapLevelFullySupported caps.ReactionCount = 1 + if isModelRoom { + caps.Edit = event.CapLevelRejected + caps.EditMaxCount = 0 + caps.EditMaxAge = nil + } else { + caps.Edit = event.CapLevelFullySupported + caps.EditMaxCount = 10 + caps.EditMaxAge = ptr.Ptr(jsontime.S(AIEditMaxAge)) + } } else { // Use explicit rejected levels so features remain visible in // com.beeper.room_features instead of being omitted by omitempty. @@ -1149,18 +1163,32 @@ func (oc *AIClient) defaultModelForProvider() string { if loginMeta == nil { return DefaultModelOpenRouter } - providers := oc.connector.Config.Providers - switch loginMeta.Provider { case ProviderOpenAI: - if providers.OpenAI.DefaultModel != "" { - return providers.OpenAI.DefaultModel - } + return oc.defaultModelSelection(ProviderOpenAI).Primary + case ProviderOpenRouter, ProviderMagicProxy: + return oc.defaultModelSelection(ProviderOpenRouter).Primary + default: + return DefaultModelOpenRouter + } +} + +func (oc *AIClient) defaultModelSelection(provider string) ModelSelectionConfig { + if oc == nil || oc.connector == nil || oc.connector.Config.Agents == nil || oc.connector.Config.Agents.Defaults == nil || oc.connector.Config.Agents.Defaults.Model == nil { + return ModelSelectionConfig{Primary: defaultModelForProviderName(provider)} + } + selection := *oc.connector.Config.Agents.Defaults.Model + if strings.TrimSpace(selection.Primary) == "" { + selection.Primary = defaultModelForProviderName(provider) + } + return selection +} + +func defaultModelForProviderName(provider string) string { + switch strings.ToLower(strings.TrimSpace(provider)) { + case ProviderOpenAI: return DefaultModelOpenAI case ProviderOpenRouter, ProviderMagicProxy: - if providers.OpenRouter.DefaultModel != "" { - return providers.OpenRouter.DefaultModel - } return DefaultModelOpenRouter default: return DefaultModelOpenRouter @@ -1217,8 +1245,8 @@ func (oc *AIClient) profilePromptSupplement() string { func getLinkPreviewConfig(connectorConfig *Config) LinkPreviewConfig { config := DefaultLinkPreviewConfig() - if connectorConfig.LinkPreviews != nil { - cfg := connectorConfig.LinkPreviews + if connectorConfig.Tools.Links != nil { + cfg := connectorConfig.Tools.Links // Apply explicit settings only if they differ from zero values if !cfg.Enabled { config.Enabled = cfg.Enabled @@ -1491,24 +1519,21 @@ func (oc *AIClient) isGroupChat(ctx context.Context, portal *bridgev2.Portal) bo return len(members) > 2 } +func (oc *AIClient) defaultPDFEngine() string { + if oc != nil && oc.connector != nil { + return oc.connector.defaultPDFEngineForInit() + } + return "mistral-ocr" +} + // effectivePDFEngine returns the PDF engine to use for the given portal. -// Priority: room-level PDFConfig > provider-level config > default "mistral-ocr" +// Priority: room-level PDFConfig > agent defaults > default "mistral-ocr" func (oc *AIClient) effectivePDFEngine(meta *PortalMetadata) string { // Room-level override if meta != nil && meta.PDFConfig != nil && meta.PDFConfig.Engine != "" { return meta.PDFConfig.Engine } - - // Provider-level config - loginMeta := loginMetadata(oc.UserLogin) - switch loginMeta.Provider { - case ProviderOpenRouter, ProviderMagicProxy: - if engine := oc.connector.Config.Providers.OpenRouter.DefaultPDFEngine; engine != "" { - return engine - } - } - - return "mistral-ocr" // Default + return oc.defaultPDFEngine() } // validateModel checks if a model is available for this user @@ -1575,8 +1600,8 @@ func resolveModelIDFromManifest(modelID string) string { return "" } -// listAvailableModels fetches models from OpenAI API and caches them -// Returns ModelInfo list from the provider +// listAvailableModels loads models from the derived catalog and caches them. +// The implicit catalog is fed from the OpenRouter-backed manifest. func (oc *AIClient) listAvailableModels(ctx context.Context, forceRefresh bool) ([]ModelInfo, error) { meta := loginMetadata(oc.UserLogin) @@ -1663,77 +1688,18 @@ func (oc *AIClient) updateAssistantGeneratedFiles(ctx context.Context, portal *b oc.Log().Warn().Msg("No assistant message found to update with async GeneratedFiles") } -func (oc *AIClient) promptContextToDispatchMessages( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - promptContext PromptContext, -) []openai.ChatCompletionMessageParamUnion { - promptMessages := bridgesdk.PromptContextToChatCompletionMessages(promptContext.PromptContext, oc.isOpenRouterProvider()) - promptMessages = oc.augmentPromptWithIntegrations(ctx, portal, meta, promptMessages) - if meta != nil && IsGoogleModel(oc.effectiveModel(meta)) { - promptMessages = SanitizeGoogleTurnOrdering(promptMessages) - } - return promptMessages -} - type historyLoadResult struct { rows []*database.Message hasVision bool resetAt int64 } -// fetchHistoryRows resolves the history limit, extracts the resetAt cutoff, -// and fetches the last N messages. Returns nil when the history limit is zero. -func (oc *AIClient) fetchHistoryRows( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, -) (*historyLoadResult, error) { - historyLimit := oc.historyLimit(ctx, portal, meta) - if historyLimit <= 0 { - return nil, nil - } - resetAt := int64(0) - if meta != nil { - resetAt = meta.SessionResetAt - } - history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(ctx, portal.PortalKey, historyLimit) - if err != nil { - return nil, fmt.Errorf("failed to load prompt history: %w", err) - } - return &historyLoadResult{ - rows: history, - hasVision: oc.getModelCapabilitiesForMeta(ctx, meta).SupportsVision, - resetAt: resetAt, - }, nil -} - func (oc *AIClient) loadHistoryMessages( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, ) ([]PromptMessage, error) { - hr, err := oc.fetchHistoryRows(ctx, portal, meta) - if err != nil { - return nil, err - } - if hr == nil { - return nil, nil - } - var messages []PromptMessage - for i := len(hr.rows) - 1; i >= 0; i-- { - msgMeta := messageMeta(hr.rows[i]) - if !shouldIncludeInHistory(msgMeta) { - continue - } - if hr.resetAt > 0 && hr.rows[i].Timestamp.UnixMilli() < hr.resetAt { - continue - } - injectImages := hr.hasVision && i < maxHistoryImageMessages - messages = append(messages, oc.historyMessageBundle(ctx, msgMeta, injectImages)...) - } - return messages, nil + return oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{mode: historyReplayNormal}) } func (oc *AIClient) buildBaseContext( @@ -1741,9 +1707,9 @@ func (oc *AIClient) buildBaseContext( portal *bridgev2.Portal, meta *PortalMetadata, ) (PromptContext, error) { - var promptContext PromptContext - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, maybePrependSessionGreeting(ctx, portal, meta, nil, oc.log)) - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, oc.buildSystemMessages(ctx, portal, meta)) + promptContext := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, true), + } historyMessages, err := oc.loadHistoryMessages(ctx, portal, meta) if err != nil { @@ -1754,18 +1720,6 @@ func (oc *AIClient) buildBaseContext( return promptContext, nil } -func (oc *AIClient) buildBasePrompt( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, -) ([]openai.ChatCompletionMessageParamUnion, error) { - promptContext, err := oc.buildBaseContext(ctx, portal, meta) - if err != nil { - return nil, err - } - return oc.promptContextToDispatchMessages(ctx, portal, meta, promptContext), nil -} - func (oc *AIClient) applyAbortHint(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, body string) string { if meta == nil || !meta.AbortedLastRun { return body @@ -1788,8 +1742,9 @@ type inboundPromptResult struct { } // prepareInboundPromptContext builds the base context, resolves inbound context, -// appends the meta system prompt, resolves body overrides, and applies the abort hint. -// Callers must call applyUntrustedPrefix at the appropriate point in message assembly. +// appends trusted inbound metadata to the system prompt, resolves body overrides, +// and applies the abort hint. Untrusted inbound prefixes are returned separately +// so callers can place them deterministically in the user prompt body. func (oc *AIClient) prepareInboundPromptContext( ctx context.Context, portal *bridgev2.Portal, @@ -1797,12 +1752,19 @@ func (oc *AIClient) prepareInboundPromptContext( userText string, eventID id.EventID, ) (inboundPromptResult, error) { - promptContext, err := oc.buildBaseContext(ctx, portal, meta) + promptContext := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, true), + } + historyMessages, err := oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{ + mode: historyReplayNormal, + excludeMessageID: networkid.MessageID(eventID), + }) if err != nil { return inboundPromptResult{}, err } + promptContext.Messages = append(promptContext.Messages, historyMessages...) inboundCtx := oc.resolvePromptInboundContext(ctx, portal, userText, eventID) - bridgesdk.AppendPromptText(&promptContext.SystemPrompt, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) + AppendPromptText(&promptContext.SystemPrompt, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) resolved := strings.TrimSpace(userText) if body := strings.TrimSpace(inboundCtx.BodyForAgent); body != "" { @@ -1819,50 +1781,6 @@ func (oc *AIClient) prepareInboundPromptContext( }, nil } -func (r *inboundPromptResult) applyUntrustedPrefix() { - if r.UntrustedPrefix != "" { - r.ResolvedBody = r.UntrustedPrefix + "\n\n" + r.ResolvedBody - } -} - -func (oc *AIClient) buildContextWithLinkContext( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - latest string, - rawEventContent map[string]any, - eventID id.EventID, -) (PromptContext, error) { - result, err := oc.prepareInboundPromptContext(ctx, portal, meta, latest, eventID) - if err != nil { - return PromptContext{}, err - } - - if linkContext := oc.buildLinkContext(ctx, latest, rawEventContent); linkContext != "" { - result.ResolvedBody += linkContext - } - - if portal != nil && portal.MXID != "" { - reactionFeedback := DrainReactionFeedback(portal.MXID) - if len(reactionFeedback) > 0 { - if feedbackText := FormatReactionFeedback(reactionFeedback); feedbackText != "" { - result.ResolvedBody = feedbackText + "\n" + result.ResolvedBody - } - } - } - - result.applyUntrustedPrefix() - - result.PromptContext.Messages = append(result.PromptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: result.ResolvedBody, - }}, - }) - return result.PromptContext, nil -} - // buildLinkContext extracts URLs from the message, fetches previews, and returns formatted context. func (oc *AIClient) buildLinkContext(ctx context.Context, message string, rawEventContent map[string]any) string { config := getLinkPreviewConfig(&oc.connector.Config) @@ -1923,8 +1841,8 @@ func (oc *AIClient) buildLinkContext(ctx context.Context, message string, rawEve return FormatPreviewsForContext(allPreviews, config.MaxContentChars) } -// buildPromptWithMedia builds a prompt with media content. -func (oc *AIClient) buildContextWithMedia( +// buildMediaTurnContext builds a prompt turn with media content. +func (oc *AIClient) buildMediaTurnContext( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, @@ -1935,67 +1853,15 @@ func (oc *AIClient) buildContextWithMedia( mediaType pendingMessageType, eventID id.EventID, ) (PromptContext, error) { - result, err := oc.prepareInboundPromptContext(ctx, portal, meta, caption, eventID) - if err != nil { - return PromptContext{}, err - } - result.applyUntrustedPrefix() - blocks := make([]PromptBlock, 0, 2) - if strings.TrimSpace(result.ResolvedBody) != "" { - blocks = append(blocks, PromptBlock{Type: PromptBlockText, Text: result.ResolvedBody}) - } - - switch mediaType { - case pendingTypeImage: - b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, mediaURL, encryptedFile, 20, mimeType) // 20MB limit for images - if err != nil { - return PromptContext{}, fmt.Errorf("failed to download image: %w", err) - } - blocks = append(blocks, PromptBlock{ - Type: PromptBlockImage, - ImageB64: b64Data, - MimeType: actualMimeType, - }) - - case pendingTypePDF: - b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, mediaURL, encryptedFile, 50, mimeType) // 50MB limit - if err != nil { - return PromptContext{}, fmt.Errorf("failed to download PDF: %w", err) - } - if actualMimeType == "" { - actualMimeType = "application/pdf" - } - blocks = append(blocks, PromptBlock{ - Type: PromptBlockFile, - FileB64: bridgesdk.BuildDataURL(actualMimeType, b64Data), - Filename: "document.pdf", - MimeType: actualMimeType, - }) - - case pendingTypeAudio: - if strings.TrimSpace(result.ResolvedBody) == "" { - blocks = append(blocks, PromptBlock{ - Type: PromptBlockText, - Text: fmt.Sprintf("Audio attachment: %s", mediaURL), - }) - } - - case pendingTypeVideo: - if strings.TrimSpace(result.ResolvedBody) == "" { - blocks = append(blocks, PromptBlock{ - Type: PromptBlockText, - Text: fmt.Sprintf("Video attachment: %s", mediaURL), - }) - } - - default: - return PromptContext{}, fmt.Errorf("unsupported media type: %s", mediaType) - } - result.PromptContext.Messages = append(result.PromptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: blocks, + return oc.buildPromptContextForTurn(ctx, portal, meta, caption, eventID, currentTurnPromptOptions{ + currentTurnTextOptions: currentTurnTextOptions{includeLinkScope: true}, + attachment: &turnAttachmentOptions{ + mediaURL: mediaURL, + mimeType: mimeType, + encryptedFile: encryptedFile, + mediaType: mediaType, + }, }) - return result.PromptContext, nil } // buildPromptUpToMessage builds a prompt including messages up to and including the specified message @@ -2006,55 +1872,21 @@ func (oc *AIClient) buildContextUpToMessage( targetMessageID networkid.MessageID, newBody string, ) (PromptContext, error) { - var promptContext PromptContext - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, oc.buildSystemMessages(ctx, portal, meta)) - - hr, err := oc.fetchHistoryRows(ctx, portal, meta) + base := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, false), + } + historyMessages, err := oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{ + mode: historyReplayRewrite, + targetMessageID: targetMessageID, + }) if err != nil { return PromptContext{}, err } - if hr != nil { - // Add messages up to the target message, replacing the target with newBody - for i := len(hr.rows) - 1; i >= 0; i-- { - msg := hr.rows[i] - msgMeta := messageMeta(msg) - - // Stop after adding the target message - if msg.ID == targetMessageID { - body := newBody - promptContext.Messages = append(promptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: body, - }}, - }) - break - } - - if !shouldIncludeInHistory(msgMeta) { - continue - } - if hr.resetAt > 0 && msg.Timestamp.UnixMilli() < hr.resetAt { - continue - } - - injectImages := hr.hasVision && i < maxHistoryImageMessages - promptContext.Messages = append(promptContext.Messages, oc.historyMessageBundle(ctx, msgMeta, injectImages)...) - } - } else { - body := strings.TrimSpace(newBody) - body = airuntime.SanitizeChatMessageForDisplay(body, true) - promptContext.Messages = append(promptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: body, - }}, - }) - } - - return promptContext, nil + base.Messages = append(base.Messages, historyMessages...) + body := strings.TrimSpace(newBody) + body = airuntime.SanitizeChatMessageForDisplay(body, true) + base.Messages = append(base.Messages, newUserTextPromptMessage(body)) + return base, nil } // downloadAndEncodeMedia downloads media and returns base64-encoded data. @@ -2071,27 +1903,6 @@ func (oc *AIClient) downloadAndEncodeMedia(ctx context.Context, mxcURL string, e return base64.StdEncoding.EncodeToString(data), mimeType, nil } -// getAudioFormat extracts the audio format from a MIME type for OpenRouter API -func getAudioFormat(mimeType string) string { - switch mimeType { - case "audio/wav", "audio/x-wav": - return "wav" - case "audio/mpeg", "audio/mp3": - return "mp3" - case "audio/webm": - return "webm" - case "audio/ogg": - return "ogg" - case "audio/flac": - return "flac" - case "audio/mp4", "audio/x-m4a": - return "mp4" - default: - // Default to mp3 for unknown formats - return "mp3" - } -} - // ensureGhostDisplayName ensures the ghost has its display name set before sending messages. // This fixes the issue where ghosts appear with raw user IDs instead of formatted names. func (oc *AIClient) ensureGhostDisplayName(ctx context.Context, modelID string) { @@ -2223,8 +2034,9 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { ctx = withInboundContext(ctx, inboundCtx) rawEventContent := map[string]any(nil) if last.Event != nil && last.Event.Content.Raw != nil { - rawEventContent = last.Event.Content.Raw + rawEventContent = clonePendingRawMap(last.Event.Content.Raw) } + pendingEvent := snapshotPendingEvent(last.Event) extraStatusEvents := make([]*event.Event, 0, len(entries)-1) if len(entries) > 1 { @@ -2240,7 +2052,7 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { } // Build prompt with combined body - promptContext, err := oc.buildContextWithLinkContext(statusCtx, last.Portal, last.Meta, combinedBody, rawEventContent, last.Event.ID) + promptContext, err := oc.buildCurrentTurnWithLinks(statusCtx, last.Portal, last.Meta, combinedBody, rawEventContent, last.Event.ID) if err != nil { oc.loggerForContext(ctx).Err(err).Msg("Failed to build prompt for debounced messages") oc.notifyMatrixSendFailure(statusCtx, last.Portal, last.Event, err) @@ -2261,7 +2073,6 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { Timestamp: agentremote.MatrixEventTimestamp(last.Event), } setCanonicalTurnDataFromPromptMessages(userMessage.Metadata.(*MessageMetadata), promptTail(promptContext, 1)) - ensureCanonicalUserMessage(userMessage) // Save user message to database - we must do this ourselves since we already // returned Pending: true to the bridge framework when debouncing started @@ -2283,7 +2094,7 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { } pending := pendingMessage{ - Event: last.Event, + Event: pendingEvent, Portal: last.Portal, Meta: last.Meta, InboundContext: &inboundCtx, @@ -2300,14 +2111,14 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { } queueItem := pendingQueueItem{ pending: pending, - messageID: string(last.Event.ID), + messageID: string(pendingEvent.ID), summaryLine: combinedRaw, enqueuedAt: time.Now().UnixMilli(), rawEventContent: rawEventContent, } queueSettings, _, _, _ := oc.resolveQueueSettingsForPortal(statusCtx, last.Portal, last.Meta, "", airuntime.QueueInlineOptions{}) - _, _ = oc.dispatchOrQueue(statusCtx, last.Event, last.Portal, last.Meta, nil, queueItem, queueSettings, promptContext) + _, _ = oc.dispatchOrQueue(statusCtx, pendingEvent, last.Portal, last.Meta, nil, queueItem, queueSettings, promptContext) } diff --git a/bridges/ai/client_capabilities_test.go b/bridges/ai/client_capabilities_test.go index f974f9b0..8bd55411 100644 --- a/bridges/ai/client_capabilities_test.go +++ b/bridges/ai/client_capabilities_test.go @@ -29,8 +29,8 @@ func TestGetCapabilities_ModelRoomAllowsReplyEditReaction(t *testing.T) { if caps.Reply != event.CapLevelFullySupported { t.Fatalf("expected reply fully supported in model room, got %v", caps.Reply) } - if caps.Edit != event.CapLevelFullySupported { - t.Fatalf("expected edit fully supported in model room, got %v", caps.Edit) + if caps.Edit != event.CapLevelRejected { + t.Fatalf("expected edit rejected in model room, got %v", caps.Edit) } if caps.Reaction != event.CapLevelFullySupported { t.Fatalf("expected reaction fully supported in model room, got %v", caps.Reaction) @@ -40,8 +40,12 @@ func TestGetCapabilities_ModelRoomAllowsReplyEditReaction(t *testing.T) { if err != nil { t.Fatalf("failed to marshal room features: %v", err) } - if !strings.Contains(string(raw), `"reaction":2`) { - t.Fatalf("expected serialized room features to contain reaction=2, got: %s", string(raw)) + rawJSON := string(raw) + if !strings.Contains(rawJSON, `"reaction":2`) { + t.Fatalf("expected serialized room features to contain reaction=2, got: %s", rawJSON) + } + if !strings.Contains(rawJSON, `"edit":-2`) { + t.Fatalf("expected serialized room features to contain edit=-2, got: %s", rawJSON) } } diff --git a/bridges/ai/command_registry.go b/bridges/ai/command_registry.go index 84fba72f..57afaefe 100644 --- a/bridges/ai/command_registry.go +++ b/bridges/ai/command_registry.go @@ -96,7 +96,6 @@ func registerModuleCommands(defs []integrationruntime.CommandDefinition) { ce.Ctx, ce.Portal, meta, - ce, commandName, ce.Args, ce.RawArgs, diff --git a/bridges/ai/compaction_summarization.go b/bridges/ai/compaction_summarization.go index 4d425d22..3883edd1 100644 --- a/bridges/ai/compaction_summarization.go +++ b/bridges/ai/compaction_summarization.go @@ -614,18 +614,20 @@ func injectSystemPromptAtFirstNonSystem( func (oc *AIClient) applyCompactionModelSummaryAndRefresh( ctx context.Context, meta *PortalMetadata, - originalPrompt []openai.ChatCompletionMessageParamUnion, - compactedPrompt []openai.ChatCompletionMessageParamUnion, + originalPrompt PromptContext, + compactedPrompt PromptContext, decision airuntime.CompactionDecision, contextWindowTokens int, -) []openai.ChatCompletionMessageParamUnion { - out := compactedPrompt +) PromptContext { + originalMessages := promptContextToChatCompletionMessages(originalPrompt, false) + compactedMessages := promptContextToChatCompletionMessages(compactedPrompt, false) + out := compactedMessages if oc.pruningSummarizationEnabled() { - dropped := selectDroppedCompactionMessages(originalPrompt, compactedPrompt, decision.DroppedCount) + dropped := selectDroppedCompactionMessages(originalMessages, compactedMessages, decision.DroppedCount) if len(dropped) > 0 { model := resolveCompactionSummaryModel(oc.effectiveModel(meta), oc.pruningSummarizationModel()) allMessages := slices.Clone(dropped) - allMessages = append(allMessages, compactedPrompt...) + allMessages = append(allMessages, compactedMessages...) adaptive := computeCompactionAdaptiveChunkRatio(allMessages, model, contextWindowTokens) maxChunkTokens := int(math.Floor(float64(contextWindowTokens)*adaptive)) - compactionSummarizationOverhead if maxChunkTokens <= 0 { @@ -653,5 +655,5 @@ func (oc *AIClient) applyCompactionModelSummaryAndRefresh( if refresh := strings.TrimSpace(oc.pruningPostCompactionRefreshPrompt()); refresh != "" { out = injectSystemPromptAtFirstNonSystem(out, refresh) } - return out + return chatMessagesToPromptContext(out) } diff --git a/bridges/ai/config_example_sync_test.go b/bridges/ai/config_example_sync_test.go new file mode 100644 index 00000000..1d275655 --- /dev/null +++ b/bridges/ai/config_example_sync_test.go @@ -0,0 +1,52 @@ +package ai + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestMainConfigExampleNetworkBlockMatchesEmbeddedExample(t *testing.T) { + mainConfigPath := filepath.Join("..", "..", "config.example.yaml") + mainData, err := os.ReadFile(mainConfigPath) + if err != nil { + t.Fatalf("read %s: %v", mainConfigPath, err) + } + + var mainDoc map[string]any + if err := yaml.Unmarshal(mainData, &mainDoc); err != nil { + t.Fatalf("unmarshal %s: %v", mainConfigPath, err) + } + + networkRaw, ok := mainDoc["network"] + if !ok { + t.Fatalf("%s is missing top-level network block", mainConfigPath) + } + network, ok := networkRaw.(map[string]any) + if !ok { + t.Fatalf("%s network block has unexpected type %T", mainConfigPath, networkRaw) + } + + embeddedPath := "integrations_example-config.yaml" + embeddedData, err := os.ReadFile(embeddedPath) + if err != nil { + t.Fatalf("read %s: %v", embeddedPath, err) + } + + var embeddedDoc map[string]any + if err := yaml.Unmarshal(embeddedData, &embeddedDoc); err != nil { + t.Fatalf("unmarshal %s: %v", embeddedPath, err) + } + + if reflect.DeepEqual(network, embeddedDoc) { + return + } + + gotJSON, _ := json.MarshalIndent(network, "", " ") + wantJSON, _ := json.MarshalIndent(embeddedDoc, "", " ") + t.Fatalf("config.example.yaml network block drifted from %s\n--- got ---\n%s\n--- want ---\n%s", embeddedPath, gotJSON, wantJSON) +} diff --git a/bridges/ai/connector.go b/bridges/ai/connector.go index ed9238bf..1cc0c84e 100644 --- a/bridges/ai/connector.go +++ b/bridges/ai/connector.go @@ -37,7 +37,7 @@ type OpenAIConnector struct { br *bridgev2.Bridge Config Config db *dbutil.Database - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*AIClient, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI @@ -55,10 +55,16 @@ func (oc *OpenAIConnector) applyRuntimeDefaults() { oc.Config.ModelCacheDuration = 6 * time.Hour } bridgesdk.ApplyDefaultCommandPrefix(&oc.Config.Bridge.CommandPrefix, "!ai") - if oc.Config.Pruning == nil { - oc.Config.Pruning = airuntime.DefaultPruningConfig() + if oc.Config.Agents == nil { + oc.Config.Agents = &AgentsConfig{} + } + if oc.Config.Agents.Defaults == nil { + oc.Config.Agents.Defaults = &AgentDefaultsConfig{} + } + if oc.Config.Agents.Defaults.Compaction == nil { + oc.Config.Agents.Defaults.Compaction = airuntime.DefaultPruningConfig() } else { - oc.Config.Pruning = airuntime.ApplyPruningDefaults(oc.Config.Pruning) + oc.Config.Agents.Defaults.Compaction = airuntime.ApplyPruningDefaults(oc.Config.Agents.Defaults.Compaction) } } diff --git a/bridges/ai/constructors.go b/bridges/ai/constructors.go index 6f8d3687..920359e7 100644 --- a/bridges/ai/constructors.go +++ b/bridges/ai/constructors.go @@ -3,7 +3,6 @@ package ai import ( "context" - "go.mau.fi/util/configupgrade" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/commands" @@ -18,7 +17,7 @@ func NewAIConnector() *OpenAIConnector { oc := &OpenAIConnector{ clients: make(map[networkid.UserLoginID]bridgev2.NetworkAPI), } - oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*AIClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "ai", Description: "AI Chats for Beeper, built on mautrix-go bridgev2.", ProtocolID: "ai", @@ -59,15 +58,14 @@ func NewAIConnector() *OpenAIConnector { BeeperBridgeType: "ai", DefaultPort: 29345, DefaultCommandPrefix: func() string { - return oc.Config.Bridge.CommandPrefix + return bridgesdk.ResolveCommandPrefix(oc.Config.Bridge.CommandPrefix, "!ai") }, - ExampleConfig: exampleNetworkConfig, - ConfigData: &oc.Config, - ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + ExampleConfig: exampleNetworkConfig, + ConfigData: &oc.Config, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, FillBridgeInfo: func(portal *bridgev2.Portal, content *event.BridgeEventContent) { applyAgentRemoteBridgeInfo(portal, portalMeta(portal), content) }, diff --git a/bridges/ai/constructors_test.go b/bridges/ai/constructors_test.go index 362c5280..fa7fca51 100644 --- a/bridges/ai/constructors_test.go +++ b/bridges/ai/constructors_test.go @@ -39,6 +39,9 @@ func TestNewAIConnectorUsesSDKConfig(t *testing.T) { if name.DefaultPort != 29345 { t.Fatalf("unexpected default port %d", name.DefaultPort) } + if name.DefaultCommandPrefix != "!ai" { + t.Fatalf("unexpected default command prefix %q", name.DefaultCommandPrefix) + } } func TestNewAIConnectorInitializesClientCacheMap(t *testing.T) { diff --git a/bridges/ai/desktop_api_sessions.go b/bridges/ai/desktop_api_sessions.go index 05d4123d..29121000 100644 --- a/bridges/ai/desktop_api_sessions.go +++ b/bridges/ai/desktop_api_sessions.go @@ -158,11 +158,11 @@ func (oc *AIClient) desktopAPIInstances() map[string]DesktopAPIInstance { if oc == nil || oc.UserLogin == nil { return instances } - meta := loginMetadata(oc.UserLogin) - if meta == nil || meta.ServiceTokens == nil { + creds := loginCredentials(loginMetadata(oc.UserLogin)) + if creds == nil || creds.ServiceTokens == nil { return instances } - for name, instance := range meta.ServiceTokens.DesktopAPIInstances { + for name, instance := range creds.ServiceTokens.DesktopAPIInstances { key := normalizeDesktopInstanceName(name) if key == "" { continue @@ -172,9 +172,11 @@ func (oc *AIClient) desktopAPIInstances() map[string]DesktopAPIInstance { } instances[key] = instance } - if token := strings.TrimSpace(meta.ServiceTokens.DesktopAPI); token != "" { - if _, ok := instances[desktopDefaultInstance]; !ok { - instances[desktopDefaultInstance] = DesktopAPIInstance{Token: token} + if token := strings.TrimSpace(creds.ServiceTokens.DesktopAPI); token != "" { + instance := instances[desktopDefaultInstance] + if strings.TrimSpace(instance.Token) == "" { + instance.Token = token + instances[desktopDefaultInstance] = instance } } return instances diff --git a/bridges/ai/desktop_api_sessions_test.go b/bridges/ai/desktop_api_sessions_test.go new file mode 100644 index 00000000..3b97cb4e --- /dev/null +++ b/bridges/ai/desktop_api_sessions_test.go @@ -0,0 +1,43 @@ +package ai + +import ( + "testing" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +func TestDesktopAPIInstancesMergesFallbackTokenIntoDefaultInstance(t *testing.T) { + client := &AIClient{ + UserLogin: &bridgev2.UserLogin{ + UserLogin: &database.UserLogin{ + ID: networkid.UserLoginID("login"), + Metadata: &UserLoginMetadata{ + Credentials: &LoginCredentials{ + ServiceTokens: &ServiceTokens{ + DesktopAPI: "fallback-token", + DesktopAPIInstances: map[string]DesktopAPIInstance{ + "default": {BaseURL: "https://desktop.example"}, + }, + }, + }, + }, + }, + Log: zerolog.Nop(), + }, + } + + instances := client.desktopAPIInstances() + got, ok := instances[desktopDefaultInstance] + if !ok { + t.Fatal("expected default desktop API instance") + } + if got.Token != "fallback-token" { + t.Fatalf("expected fallback token to be merged, got %#v", got) + } + if got.BaseURL != "https://desktop.example" { + t.Fatalf("expected base URL to be preserved, got %#v", got) + } +} diff --git a/bridges/ai/error_logging.go b/bridges/ai/error_logging.go index add53183..13591dc5 100644 --- a/bridges/ai/error_logging.go +++ b/bridges/ai/error_logging.go @@ -9,14 +9,14 @@ import ( "github.com/rs/zerolog" ) -func logResponsesFailure(log zerolog.Logger, err error, params responses.ResponseNewParams, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, stage string) { - logProviderFailure(log, err, meta, messages, stage, "Responses API failure", func(event *zerolog.Event) { +func logResponsesFailure(log zerolog.Logger, err error, params responses.ResponseNewParams, meta *PortalMetadata, prompt PromptContext, stage string) { + logProviderFailure(log, err, meta, prompt, stage, "Responses API failure", func(event *zerolog.Event) { addResponsesParamsSummary(event, params) }) } -func logChatCompletionsFailure(log zerolog.Logger, err error, params openai.ChatCompletionNewParams, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, stage string) { - logProviderFailure(log, err, meta, messages, stage, "Chat Completions failure", func(event *zerolog.Event) { +func logChatCompletionsFailure(log zerolog.Logger, err error, params openai.ChatCompletionNewParams, meta *PortalMetadata, prompt PromptContext, stage string) { + logProviderFailure(log, err, meta, prompt, stage, "Chat Completions failure", func(event *zerolog.Event) { addChatParamsSummary(event, params) }) } @@ -25,13 +25,13 @@ func logProviderFailure( log zerolog.Logger, err error, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, stage string, msg string, addSummary func(*zerolog.Event), ) { event := log.Error().Err(err).Str("stage", stage) - addRequestSummary(event, meta, messages) + addRequestSummary(event, meta, prompt) if addSummary != nil { addSummary(event) } @@ -39,7 +39,7 @@ func logProviderFailure( event.Msg(msg) } -func addRequestSummary(event *zerolog.Event, metadata *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion) { +func addRequestSummary(event *zerolog.Event, metadata *PortalMetadata, prompt PromptContext) { if event == nil { return } @@ -54,9 +54,8 @@ func addRequestSummary(event *zerolog.Event, metadata *PortalMetadata, messages event.Str("runtime_model_override", metadata.RuntimeModelOverride) } } - event.Int("message_count", len(messages)) - event.Bool("has_audio", hasAudioContent(messages)) - event.Bool("has_multimodal", hasMultimodalContent(messages)) + event.Int("message_count", len(prompt.Messages)) + event.Bool("has_multimodal", promptHasMultimodalContent(prompt)) } func addResponsesParamsSummary(event *zerolog.Event, params responses.ResponseNewParams) { diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 90248b20..4604d084 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -17,7 +17,6 @@ import ( "github.com/beeper/agentremote" airuntime "github.com/beeper/agentremote/pkg/runtime" "github.com/beeper/agentremote/pkg/shared/stringutil" - bridgesdk "github.com/beeper/agentremote/sdk" ) func messageSendStatusError(err error, message string, reason event.MessageStatusReason) error { @@ -251,15 +250,16 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri // Get raw event content for link previews var rawEventContent map[string]any if msg.Event != nil && msg.Event.Content.Raw != nil { - rawEventContent = msg.Event.Content.Raw + rawEventContent = clonePendingRawMap(msg.Event.Content.Raw) } + pendingEvent := snapshotPendingEvent(msg.Event) eventID := id.EventID("") if msg.Event != nil { eventID = msg.Event.ID } - promptContext, err := oc.buildContextWithLinkContext(runCtx, portal, runMeta, body, rawEventContent, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(runCtx, portal, runMeta, body, rawEventContent, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the message. Try again.", "") } @@ -280,7 +280,7 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri } pending := pendingMessage{ - Event: msg.Event, + Event: pendingEvent, Portal: portal, Meta: runMeta, InboundContext: &inboundCtx, @@ -301,7 +301,7 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri enqueuedAt: time.Now().UnixMilli(), rawEventContent: rawEventContent, } - dbMsg, isPending := oc.dispatchOrQueue(runCtx, msg.Event, portal, runMeta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(runCtx, pendingEvent, portal, runMeta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, @@ -323,15 +323,17 @@ func (oc *AIClient) HandleMatrixTyping(ctx context.Context, typing *bridgev2.Mat // HandleMatrixEdit handles edits to previously sent messages func (oc *AIClient) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixEdit) error { - if edit.Content == nil || edit.EditTarget == nil { - return errors.New("invalid edit: missing content or target") - } - portal := edit.Portal if portal == nil { return errors.New("portal is nil") } meta := portalMeta(portal) + if meta != nil && meta.ResolvedTarget != nil && meta.ResolvedTarget.Kind == ResolvedTargetModel { + return bridgev2.ErrEditsNotSupportedInPortal + } + if edit.Content == nil || edit.EditTarget == nil { + return errors.New("invalid edit: missing content or target") + } // Get the new message body newBody := strings.TrimSpace(edit.Content.Body) @@ -436,8 +438,9 @@ func (oc *AIClient) regenerateFromEdit( queueSettings, _, _, _ := oc.resolveQueueSettingsForPortal(ctx, portal, meta, "", airuntime.QueueInlineOptions{}) isGroup := oc.isGroupChat(ctx, portal) + pendingEvent := snapshotPendingEvent(evt) pending := pendingMessage{ - Event: evt, + Event: pendingEvent, Portal: portal, Meta: meta, Type: pendingTypeEditRegenerate, @@ -454,7 +457,7 @@ func (oc *AIClient) regenerateFromEdit( summaryLine: newBody, enqueuedAt: time.Now().UnixMilli(), } - oc.dispatchOrQueueCore(ctx, evt, portal, meta, nil, queueItem, queueSettings, promptContext) + oc.dispatchOrQueueCore(ctx, pendingEvent, portal, meta, nil, queueItem, queueSettings, promptContext) return nil } @@ -572,11 +575,12 @@ func (oc *AIClient) handleMediaMessage( eventID = msg.Event.ID } - // Check capability (PDF has special OpenRouter handling via file-parser plugin) + // PDFs are normalized into file-context text before prompt assembly, so they + // do not require native provider file support. modelCaps := oc.getModelCapabilitiesForMeta(ctx, meta) supportsMedia := config.capabilityCheck(&modelCaps) - if isPDF && !supportsMedia && oc.isOpenRouterProvider() { - supportsMedia = true // OpenRouter supports PDF via file-parser plugin + if isPDF { + supportsMedia = true } queueSettings, _, _, _ := oc.resolveQueueSettingsForPortal(ctx, portal, meta, "", airuntime.QueueInlineOptions{}) @@ -599,12 +603,13 @@ func (oc *AIClient) handleMediaMessage( if msg.Content.File != nil { encryptedFile = msg.Content.File } + pendingEvent := snapshotPendingEvent(msg.Event) dispatchTextOnly := func(rawBody string) (*bridgev2.MatrixMessageResponse, error) { inboundCtx := oc.buildMatrixInboundContext(portal, msg.Event, rawBody, senderName, roomName, isGroup) promptCtx := withInboundContext(ctx, inboundCtx) body := oc.buildMatrixInboundBody(ctx, portal, meta, msg.Event, rawBody, senderName, roomName, isGroup) - promptContext, err := oc.buildContextWithLinkContext(promptCtx, portal, meta, body, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(promptCtx, portal, meta, body, nil, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the message. Try again.", "") } @@ -623,7 +628,7 @@ func (oc *AIClient) handleMediaMessage( userMessage.SendTxnID = networkid.RawTransactionID(msg.InputTransactionID) } pending := pendingMessage{ - Event: msg.Event, + Event: pendingEvent, Portal: portal, Meta: meta, InboundContext: &inboundCtx, @@ -638,7 +643,7 @@ func (oc *AIClient) handleMediaMessage( summaryLine: rawBody, enqueuedAt: time.Now().UnixMilli(), } - dbMsg, isPending := oc.dispatchOrQueue(promptCtx, msg.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(promptCtx, pendingEvent, portal, meta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, Pending: isPending, @@ -703,30 +708,6 @@ func (oc *AIClient) handleMediaMessage( } } - // If model lacks audio but agent supports audio understanding, analyze audio first. - if msgType == event.MsgAudio { - audioModel, audioFallback := oc.resolveModelForCapability(ctx, meta, func(caps ModelCapabilities) bool { return caps.SupportsAudio }, oc.resolveAudioUnderstandingModel) - if resp, err := oc.dispatchMediaUnderstandingFallback( - ctx, - audioModel, - audioFallback, - string(mediaURL), - mimeType, - encryptedFile, - caption, - hasUserCaption, - buildMediaUnderstandingPrompt(MediaCapabilityAudio), - oc.analyzeAudioWithModel, - buildMediaUnderstandingMessage("Audio", "Transcript"), - "Audio understanding failed", - "audio understanding produced empty result", - "Couldn't analyze the audio. Try again, or switch to an audio-capable model with !ai model.", - dispatchTextOnly, - ); resp != nil || err != nil { - return resp, err - } - } - return nil, agentremote.UnsupportedMessageStatus(fmt.Errorf( "current model (%s) does not support %s; switch to a capable model using !ai model", oc.effectiveModel(meta), config.capabilityName, @@ -737,7 +718,7 @@ func (oc *AIClient) handleMediaMessage( captionForPrompt := oc.buildMatrixInboundBody(ctx, portal, meta, msg.Event, caption, senderName, roomName, isGroup) captionInboundCtx := oc.buildMatrixInboundContext(portal, msg.Event, caption, senderName, roomName, isGroup) promptCtx := withInboundContext(ctx, captionInboundCtx) - promptContext, err := oc.buildContextWithMedia(promptCtx, portal, meta, captionForPrompt, string(mediaURL), mimeType, encryptedFile, config.msgType, eventID) + promptContext, err := oc.buildMediaTurnContext(promptCtx, portal, meta, captionForPrompt, string(mediaURL), mimeType, encryptedFile, config.msgType, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the media message. Try again.", "") } @@ -770,7 +751,7 @@ func (oc *AIClient) handleMediaMessage( } pending := pendingMessage{ - Event: msg.Event, + Event: snapshotPendingEvent(msg.Event), Portal: portal, Meta: meta, InboundContext: &captionInboundCtx, @@ -788,7 +769,7 @@ func (oc *AIClient) handleMediaMessage( summaryLine: rawCaption, enqueuedAt: time.Now().UnixMilli(), } - dbMsg, isPending := oc.dispatchOrQueue(promptCtx, msg.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(promptCtx, pending.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, @@ -894,7 +875,7 @@ func (oc *AIClient) handleTextFileMessage( inboundCtx := oc.buildMatrixInboundContext(portal, msg.Event, combined, senderName, roomName, isGroup) promptCtx := withInboundContext(ctx, inboundCtx) - promptContext, err := oc.buildContextWithLinkContext(promptCtx, portal, meta, combined, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(promptCtx, portal, meta, combined, nil, eventID) if err != nil { return nil, messageSendStatusError(err, "Couldn't prepare the message. Try again.", "") } @@ -915,7 +896,7 @@ func (oc *AIClient) handleTextFileMessage( } pending := pendingMessage{ - Event: msg.Event, + Event: snapshotPendingEvent(msg.Event), Portal: portal, Meta: meta, InboundContext: &inboundCtx, @@ -930,7 +911,7 @@ func (oc *AIClient) handleTextFileMessage( summaryLine: strings.TrimSpace(rawCaption), enqueuedAt: time.Now().UnixMilli(), } - dbMsg, isPending := oc.dispatchOrQueue(promptCtx, msg.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) + dbMsg, isPending := oc.dispatchOrQueue(promptCtx, pending.Event, portal, meta, userMessage, queueItem, queueSettings, promptContext) return &bridgev2.MatrixMessageResponse{ DB: dbMsg, @@ -1102,70 +1083,14 @@ func (oc *AIClient) buildContextForRegenerate( latestUserBody string, latestUserID id.EventID, ) (PromptContext, error) { - var promptContext PromptContext - bridgesdk.AppendChatMessagesToPromptContext(&promptContext.PromptContext, oc.buildSystemMessages(ctx, portal, meta)) - - historyLimit := oc.historyLimit(ctx, portal, meta) - resetAt := int64(0) - if meta != nil { - resetAt = meta.SessionResetAt + base := PromptContext{ + SystemPrompt: oc.buildConversationSystemPromptText(ctx, portal, meta, false), } - if historyLimit > 0 { - history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(ctx, portal.PortalKey, historyLimit+2) - if err != nil { - return PromptContext{}, fmt.Errorf("failed to load prompt history: %w", err) - } - - // Determine whether to inject images into history (requires vision-capable model). - hasVision := oc.getModelCapabilitiesForMeta(ctx, meta).SupportsVision - historyBundles := make([][]PromptMessage, 0, len(history)) - - // Skip the most recent messages (last user and assistant) and build from older history - skippedUser := false - skippedAssistant := false - includedCount := 0 - for _, msg := range history { - msgMeta := messageMeta(msg) - // Skip commands and non-conversation messages - if !shouldIncludeInHistory(msgMeta) { - continue - } - if resetAt > 0 && msg.Timestamp.UnixMilli() < resetAt { - continue - } - - // Skip the last user message and last assistant message - if !skippedUser && msgMeta.Role == "user" { - skippedUser = true - continue - } - if !skippedAssistant && msgMeta.Role == "assistant" { - skippedAssistant = true - continue - } - - // Only inject images for recent messages and vision-capable models. - // This loop builds newest-to-oldest, so early entries are the most recent. - injectImages := hasVision && includedCount < maxHistoryImageMessages - includedCount++ - bundle := oc.historyMessageBundle(ctx, msgMeta, injectImages) - if len(bundle) > 0 { - historyBundles = append(historyBundles, bundle) - } - } - - for i := len(historyBundles) - 1; i >= 0; i-- { - promptContext.Messages = append(promptContext.Messages, historyBundles[i]...) - } + historyMessages, err := oc.replayHistoryMessages(ctx, portal, meta, historyReplayOptions{mode: historyReplayRegen}) + if err != nil { + return PromptContext{}, err } - - latest := latestUserBody - promptContext.Messages = append(promptContext.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: latest, - }}, - }) - return promptContext, nil + base.Messages = append(base.Messages, historyMessages...) + base.Messages = append(base.Messages, newUserTextPromptMessage(latestUserBody)) + return base, nil } diff --git a/bridges/ai/handlematrix_edit_test.go b/bridges/ai/handlematrix_edit_test.go new file mode 100644 index 00000000..d71dc0b2 --- /dev/null +++ b/bridges/ai/handlematrix_edit_test.go @@ -0,0 +1,61 @@ +package ai + +import ( + "context" + "testing" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" +) + +func TestHandleMatrixEdit_ModelRoomRejectsEdits(t *testing.T) { + oc := &AIClient{} + edit := &bridgev2.MatrixEdit{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.MessageEventContent]{ + Portal: &bridgev2.Portal{ + Portal: &database.Portal{ + OtherUserID: modelUserID("openai/gpt-5"), + Metadata: modelModeTestMeta("openai/gpt-5"), + }, + }, + Content: &event.MessageEventContent{Body: "updated"}, + }, + EditTarget: &database.Message{}, + } + + err := oc.HandleMatrixEdit(context.Background(), edit) + if err == nil { + t.Fatal("expected model room edit to be rejected") + } + if err.Error() != bridgev2.ErrEditsNotSupportedInPortal.Error() { + t.Fatalf("expected ErrEditsNotSupportedInPortal, got %v", err) + } +} + +func TestHandleMatrixEdit_AgentRoomStillUsesAgentPath(t *testing.T) { + oc := &AIClient{} + edit := &bridgev2.MatrixEdit{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.MessageEventContent]{ + Portal: &bridgev2.Portal{ + Portal: &database.Portal{ + OtherUserID: agentUserID("beeper"), + Metadata: agentModeTestMeta("beeper"), + }, + }, + Content: &event.MessageEventContent{Body: " "}, + }, + EditTarget: &database.Message{}, + } + + err := oc.HandleMatrixEdit(context.Background(), edit) + if err == nil { + t.Fatal("expected agent edit to continue into the existing handler path") + } + if err.Error() == bridgev2.ErrEditsNotSupportedInPortal.Error() { + t.Fatalf("expected agent room edit to avoid model-room rejection, got %v", err) + } + if err.Error() != "empty edit body" { + t.Fatalf("expected empty edit body error from existing path, got %v", err) + } +} diff --git a/bridges/ai/heartbeat_execute.go b/bridges/ai/heartbeat_execute.go index d1297bec..cb1e37a2 100644 --- a/bridges/ai/heartbeat_execute.go +++ b/bridges/ai/heartbeat_execute.go @@ -79,7 +79,7 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, sessionResolution := oc.resolveHeartbeatSession(agentID, heartbeat) storeKey := strings.TrimSpace(sessionResolution.SessionKey) - sessionPortal, sessionKey, err := oc.resolveHeartbeatSessionPortal(agentID, heartbeat) + sessionPortal, sessionKey, err := oc.resolveHeartbeatSessionPortal(agentID, heartbeat, sessionResolution) if err != nil || sessionPortal == nil || sessionPortal.MXID == "" { oc.log.Warn().Str("agent_id", agentID).Err(err).Msg("Heartbeat skipped: no session portal") return heartbeatRunResult{Status: "skipped", Reason: "no-session"} @@ -179,22 +179,10 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, } } - promptContext, err := oc.buildContextWithHeartbeat(context.Background(), sessionPortal, promptMeta, prompt) + promptContext, err := oc.buildHeartbeatTurnContext(context.Background(), sessionPortal, promptMeta, prompt) if err != nil { oc.log.Warn().Str("agent_id", agentID).Str("reason", reason).Err(err).Msg("Heartbeat failed to build prompt") - indicator := (*HeartbeatIndicatorType)(nil) - if hbCfg.UseIndicator { - indicator = resolveIndicatorType("failed") - } - oc.emitHeartbeatEvent(&HeartbeatEventPayload{ - TS: time.Now().UnixMilli(), - Status: "failed", - Reason: err.Error(), - Channel: hbCfg.Channel, - To: hbCfg.TargetRoom.String(), - DurationMs: time.Now().UnixMilli() - startedAtMs, - IndicatorType: indicator, - }) + oc.emitHeartbeatFailure(hbCfg, startedAtMs, err.Error()) return heartbeatRunResult{Status: "failed", Reason: err.Error()} } @@ -228,39 +216,31 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, return heartbeatRunResult{Status: res.Status, Reason: res.Reason} case <-done: oc.log.Warn().Str("agent_id", agentID).Msg("Heartbeat failed: stream completed without outcome") - indicator := (*HeartbeatIndicatorType)(nil) - if hbCfg.UseIndicator { - indicator = resolveIndicatorType("failed") - } - oc.emitHeartbeatEvent(&HeartbeatEventPayload{ - TS: time.Now().UnixMilli(), - Status: "failed", - Reason: "stream-finished-without-outcome", - Channel: hbCfg.Channel, - To: hbCfg.TargetRoom.String(), - DurationMs: time.Now().UnixMilli() - startedAtMs, - IndicatorType: indicator, - }) + oc.emitHeartbeatFailure(hbCfg, startedAtMs, "stream-finished-without-outcome") return heartbeatRunResult{Status: "failed", Reason: "heartbeat failed"} case <-timeoutCtx.Done(): oc.log.Warn().Str("agent_id", agentID).Msg("Heartbeat timed out after 2 minutes") - indicator := (*HeartbeatIndicatorType)(nil) - if hbCfg.UseIndicator { - indicator = resolveIndicatorType("failed") - } - oc.emitHeartbeatEvent(&HeartbeatEventPayload{ - TS: time.Now().UnixMilli(), - Status: "failed", - Reason: "timeout", - Channel: hbCfg.Channel, - To: hbCfg.TargetRoom.String(), - DurationMs: time.Now().UnixMilli() - startedAtMs, - IndicatorType: indicator, - }) + oc.emitHeartbeatFailure(hbCfg, startedAtMs, "timeout") return heartbeatRunResult{Status: "failed", Reason: "heartbeat timed out"} } } +func (oc *AIClient) emitHeartbeatFailure(hbCfg *HeartbeatRunConfig, startedAtMs int64, reason string) { + indicator := (*HeartbeatIndicatorType)(nil) + if hbCfg.UseIndicator { + indicator = resolveIndicatorType("failed") + } + oc.emitHeartbeatEvent(&HeartbeatEventPayload{ + TS: time.Now().UnixMilli(), + Status: "failed", + Reason: reason, + Channel: hbCfg.Channel, + To: hbCfg.TargetRoom.String(), + DurationMs: time.Now().UnixMilli() - startedAtMs, + IndicatorType: indicator, + }) +} + func drainHeartbeatSystemEvents(ownerKey string, primaryKey string, secondaryKey string) []SystemEvent { entries := drainSystemEventEntries(ownerKey, primaryKey) if sk := strings.TrimSpace(secondaryKey); sk != "" && !strings.EqualFold(strings.TrimSpace(primaryKey), sk) { @@ -282,22 +262,13 @@ func systemEventsOwnerKey(oc *AIClient) string { return string(oc.UserLogin.Bridge.DB.BridgeID) + "|" + string(oc.UserLogin.ID) } -func (oc *AIClient) buildContextWithHeartbeat(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, prompt string) (PromptContext, error) { - base, err := oc.buildBaseContext(ctx, portal, meta) - if err != nil { - return PromptContext{}, err - } - base.Messages = append(base.Messages, PromptMessage{ - Role: PromptRoleUser, - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: prompt, - }}, - }) - return base, nil -} - -func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *HeartbeatConfig) (*bridgev2.Portal, string, error) { +func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *HeartbeatConfig, preResolved ...heartbeatSessionResolution) (*bridgev2.Portal, string, error) { + var hbSession heartbeatSessionResolution + if len(preResolved) > 0 && preResolved[0].SessionKey != "" { + hbSession = preResolved[0] + } else { + hbSession = oc.resolveHeartbeatSession(agentID, heartbeat) + } session := "" if heartbeat != nil && heartbeat.Session != nil { session = strings.TrimSpace(*heartbeat.Session) @@ -307,7 +278,6 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea mainKey = strings.TrimSpace(oc.connector.Config.Session.MainKey) } if session == "" || strings.EqualFold(session, "main") || strings.EqualFold(session, "global") || (mainKey != "" && strings.EqualFold(session, mainKey)) { - hbSession := oc.resolveHeartbeatSession(agentID, heartbeat) if portal := oc.heartbeatSessionPortalCandidate(agentID, hbSession); portal != nil { return portal, portal.MXID.String(), nil } @@ -326,7 +296,6 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea } } } - hbSession := oc.resolveHeartbeatSession(agentID, heartbeat) if portal := oc.heartbeatSessionPortalCandidate(agentID, hbSession); portal != nil { return portal, portal.MXID.String(), nil } diff --git a/bridges/ai/identifiers.go b/bridges/ai/identifiers.go index 53abbfe6..f28b506e 100644 --- a/bridges/ai/identifiers.go +++ b/bridges/ai/identifiers.go @@ -200,11 +200,7 @@ func shouldIncludeInHistory(meta *MessageMetadata) bool { if meta.Role != "user" && meta.Role != "assistant" { return false } - return len(meta.CanonicalTurnData) > 0 || - strings.TrimSpace(meta.Body) != "" || - len(meta.ToolCalls) > 0 || - strings.TrimSpace(meta.MediaURL) != "" || - len(meta.GeneratedFiles) > 0 + return len(meta.CanonicalTurnData) > 0 } func loginMetadata(login *bridgev2.UserLogin) *UserLoginMetadata { diff --git a/bridges/ai/image_generation_tool.go b/bridges/ai/image_generation_tool.go index 5afa2d1e..79e2cda5 100644 --- a/bridges/ai/image_generation_tool.go +++ b/bridges/ai/image_generation_tool.go @@ -153,6 +153,9 @@ func readStringSlice(args map[string]any, key string) []string { } func resolveImageGenProvider(req imageGenRequest, btc *BridgeToolContext) (imageGenProvider, error) { + if btc == nil || btc.Client == nil || btc.Client.UserLogin == nil { + return "", errors.New("image generation is not available for this login") + } provider := strings.ToLower(strings.TrimSpace(req.Provider)) if provider != "" { switch provider { @@ -176,12 +179,38 @@ func resolveImageGenProvider(req imageGenRequest, btc *BridgeToolContext) (image } } + loginMeta := loginMetadata(btc.Client.UserLogin) + if loginMeta == nil { + return "", errors.New("image generation is not available for this login") + } + inferredProvider := inferProviderFromModel(req.Model) + if inferredProvider != "" { + switch inferredProvider { + case imageGenProviderOpenAI: + if supportsOpenAIImageGen(btc) { + return imageGenProviderOpenAI, nil + } + case imageGenProviderGemini: + if supportsGeminiImageGen(btc) { + return imageGenProviderGemini, nil + } + case imageGenProviderOpenRouter: + if supportsOpenRouterImageGen(btc) { + return imageGenProviderOpenRouter, nil + } + } + // Magic Proxy only exposes the OpenAI images route in practice, so use + // that when a requested image model belongs to an unavailable surface. + if loginMeta != nil && loginMeta.Provider == ProviderMagicProxy && supportsOpenAIImageGen(btc) { + return imageGenProviderOpenAI, nil + } + } + // Prefer OpenRouter image gen whenever it's available (Gemini models support extra controls). if supportsOpenRouterImageGen(btc) { return imageGenProviderOpenRouter, nil } - loginMeta := loginMetadata(btc.Client.UserLogin) switch loginMeta.Provider { case ProviderOpenAI: if !supportsOpenAIImageGen(btc) { @@ -228,11 +257,14 @@ func supportsOpenAIImageGen(btc *BridgeToolContext) bool { return false } loginMeta := loginMetadata(btc.Client.UserLogin) + if loginMeta == nil { + return false + } switch loginMeta.Provider { case ProviderOpenAI, ProviderMagicProxy: if loginMeta.Provider == ProviderMagicProxy { // Magic Proxy uses a per-login token+base URL, not the OpenAI config key. - return strings.TrimSpace(loginMeta.APIKey) != "" && strings.TrimSpace(loginMeta.BaseURL) != "" + return loginCredentialAPIKey(loginMeta) != "" && loginCredentialBaseURL(loginMeta) != "" } return btc.Client.connector.resolveOpenAIAPIKey(loginMeta) != "" default: @@ -255,14 +287,8 @@ func supportsGeminiImageGen(btc *BridgeToolContext) bool { } switch loginMeta.Provider { case ProviderMagicProxy: - if btc.Client.connector != nil { - services := btc.Client.connector.resolveServiceConfig(loginMeta) - if svc, ok := services[serviceGemini]; ok { - return strings.TrimSpace(svc.BaseURL) != "" && strings.TrimSpace(svc.APIKey) != "" - } - } - base := normalizeProxyBaseURL(loginMeta.BaseURL) - return base != "" && strings.TrimSpace(loginMeta.APIKey) != "" + // Magic Proxy does not expose the Gemini image generation endpoint. + return false default: return false } @@ -273,12 +299,20 @@ func normalizeOpenAIModel(model string) string { if model == "" { return defaultOpenAIImageModel } + switch inferProviderFromModel(model) { + case imageGenProviderGemini, imageGenProviderOpenRouter: + return defaultOpenAIImageModel + } _, actual := ParseModelPrefix(model) actual = strings.TrimSpace(actual) actual = strings.TrimPrefix(actual, "openai/") if actual == "" { return model } + switch strings.ToLower(actual) { + case "gpt-5-image", "gpt-5-image-mini": + return defaultOpenAIImageModel + } return actual } @@ -444,7 +478,13 @@ func isAllowedValue(value string, allowed map[string]bool) bool { } func buildOpenAIImagesBaseURL(btc *BridgeToolContext) (string, error) { + if btc == nil || btc.Client == nil || btc.Client.UserLogin == nil || btc.Client.UserLogin.Metadata == nil { + return "", errors.New("openai image generation not available for this provider") + } loginMeta := loginMetadata(btc.Client.UserLogin) + if loginMeta == nil { + return "", errors.New("openai image generation not available for this provider") + } switch loginMeta.Provider { case ProviderOpenAI: base := btc.Client.connector.resolveOpenAIBaseURL() @@ -456,7 +496,7 @@ func buildOpenAIImagesBaseURL(btc *BridgeToolContext) (string, error) { return strings.TrimSuffix(strings.TrimSpace(svc.BaseURL), "/"), nil } } - base := normalizeProxyBaseURL(loginMeta.BaseURL) + base := normalizeProxyBaseURL(loginCredentialBaseURL(loginMeta)) if base == "" { return "", errors.New("magic proxy base_url is required for image generation") } @@ -482,7 +522,7 @@ func buildGeminiBaseURL(btc *BridgeToolContext) (string, error) { return strings.TrimSuffix(strings.TrimSpace(svc.BaseURL), "/"), nil } } - base := normalizeProxyBaseURL(loginMeta.BaseURL) + base := normalizeProxyBaseURL(loginCredentialBaseURL(loginMeta)) if base == "" { return "", errors.New("magic proxy base_url is required for image generation") } @@ -592,12 +632,9 @@ func resolveOpenRouterImageGenEndpoint(btc *BridgeToolContext) (baseURL string, // Provider-specific per-login endpoints. switch meta.Provider { case ProviderMagicProxy: - base := normalizeProxyBaseURL(meta.BaseURL) - key := trim(meta.APIKey) - if base == "" || key == "" { - return "", "", false - } - return joinProxyPath(base, "/openrouter/v1"), key, true + // Magic Proxy does not expose the OpenRouter images endpoint; use the + // verified OpenAI images route instead. + return "", "", false case ProviderOpenRouter: if conn == nil { return "", "", false diff --git a/bridges/ai/image_generation_tool_magic_proxy_test.go b/bridges/ai/image_generation_tool_magic_proxy_test.go index a130f3db..7cf53751 100644 --- a/bridges/ai/image_generation_tool_magic_proxy_test.go +++ b/bridges/ai/image_generation_tool_magic_proxy_test.go @@ -1,12 +1,16 @@ package ai -import "testing" +import ( + "testing" +) -func TestResolveImageGenProviderMagicProxyPrefersOpenRouterForSimplePrompts(t *testing.T) { +func TestResolveImageGenProviderMagicProxyPrefersOpenAIForSimplePrompts(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -17,16 +21,18 @@ func TestResolveImageGenProviderMagicProxyPrefersOpenRouterForSimplePrompts(t *t if err != nil { t.Fatalf("resolveImageGenProvider returned error: %v", err) } - if got != imageGenProviderOpenRouter { - t.Fatalf("expected provider %q, got %q", imageGenProviderOpenRouter, got) + if got != imageGenProviderOpenAI { + t.Fatalf("expected provider %q, got %q", imageGenProviderOpenAI, got) } } -func TestResolveImageGenProviderMagicProxyStillPrefersOpenRouterWhenCountIsGreaterThanOne(t *testing.T) { +func TestResolveImageGenProviderMagicProxyStillPrefersOpenAIWhenCountIsGreaterThanOne(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -37,16 +43,18 @@ func TestResolveImageGenProviderMagicProxyStillPrefersOpenRouterWhenCountIsGreat if err != nil { t.Fatalf("resolveImageGenProvider returned error: %v", err) } - if got != imageGenProviderOpenRouter { - t.Fatalf("expected provider %q, got %q", imageGenProviderOpenRouter, got) + if got != imageGenProviderOpenAI { + t.Fatalf("expected provider %q, got %q", imageGenProviderOpenAI, got) } } -func TestResolveImageGenProviderMagicProxyProviderOpenAIStillRoutesToOpenRouter(t *testing.T) { +func TestResolveImageGenProviderMagicProxyProviderOpenAIUsesOpenAI(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -63,32 +71,77 @@ func TestResolveImageGenProviderMagicProxyProviderOpenAIStillRoutesToOpenRouter( } } -func TestResolveImageGenProviderMagicProxyProviderGeminiUsesGemini(t *testing.T) { +func TestResolveImageGenProviderMagicProxyModelHintFallsBackToOpenAI(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) got, err := resolveImageGenProvider(imageGenRequest{ + Model: "google/gemini-3-pro-image-preview", + Prompt: "cat", + Count: 1, + }, btc) + if err != nil { + t.Fatalf("resolveImageGenProvider returned error: %v", err) + } + if got != imageGenProviderOpenAI { + t.Fatalf("expected provider %q, got %q", imageGenProviderOpenAI, got) + } +} + +func TestResolveImageGenProviderMagicProxyProviderGeminiIsUnavailable(t *testing.T) { + meta := &UserLoginMetadata{ + Provider: ProviderMagicProxy, + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, + } + btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) + + _, err := resolveImageGenProvider(imageGenRequest{ Provider: "gemini", Prompt: "cat", Count: 1, }, btc) - if err != nil { - t.Fatalf("resolveImageGenProvider returned error: %v", err) + if err == nil { + t.Fatal("expected gemini image generation to be unavailable for magic proxy") } - if got != imageGenProviderGemini { - t.Fatalf("expected provider %q, got %q", imageGenProviderGemini, got) +} + +func TestNormalizeOpenAIModelMapsUnavailableAliasesToGPTImage1(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {name: "empty", input: "", want: "gpt-image-1"}, + {name: "prefixed alias", input: "openai/gpt-5-image", want: "gpt-image-1"}, + {name: "mini alias", input: "gpt-5-image-mini", want: "gpt-image-1"}, + {name: "gemini alias", input: "google/gemini-3-pro-image-preview", want: "gpt-image-1"}, + {name: "native openai", input: "gpt-image-1", want: "gpt-image-1"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := normalizeOpenAIModel(tc.input); got != tc.want { + t.Fatalf("normalizeOpenAIModel(%q) = %q, want %q", tc.input, got, tc.want) + } + }) } } func TestBuildOpenAIImagesBaseURLMagicProxy(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -104,8 +157,10 @@ func TestBuildOpenAIImagesBaseURLMagicProxy(t *testing.T) { func TestBuildGeminiBaseURLMagicProxy(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -117,3 +172,13 @@ func TestBuildGeminiBaseURLMagicProxy(t *testing.T) { t.Fatalf("unexpected base url: %q", baseURL) } } + +func TestResolveImageGenProviderRejectsMissingLoginMetadata(t *testing.T) { + btc := &BridgeToolContext{ + Client: &AIClient{}, + } + + if _, err := resolveImageGenProvider(imageGenRequest{Prompt: "cat"}, btc); err == nil { + t.Fatal("expected missing login metadata to be rejected") + } +} diff --git a/bridges/ai/image_understanding.go b/bridges/ai/image_understanding.go index 49dd9bf2..83fa17f3 100644 --- a/bridges/ai/image_understanding.go +++ b/bridges/ai/image_understanding.go @@ -7,8 +7,6 @@ import ( "strings" "maunium.net/go/mautrix/event" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func (oc *AIClient) canUseMediaUnderstanding(meta *PortalMetadata) bool { @@ -147,17 +145,6 @@ func (oc *AIClient) resolveVisionModelForImage(ctx context.Context, meta *Portal ) } -// resolveAudioUnderstandingModel returns an audio-capable model from the agent's model chain. -func (oc *AIClient) resolveAudioUnderstandingModel(ctx context.Context, meta *PortalMetadata) string { - return oc.resolveUnderstandingModel( - ctx, - meta, - func(caps ModelCapabilities) bool { return caps.SupportsAudio }, - func(info ModelInfo) bool { return info.SupportsAudio }, - "audio", - ) -} - func (oc *AIClient) pickModelFromCache(cache *ModelCache, provider string, supports modelInfoFilter) string { if cache == nil || len(cache.Models) == 0 { return "" @@ -224,9 +211,9 @@ func (oc *AIClient) analyzeImageWithModel( actualMimeType = "image/jpeg" } - dataURL := bridgesdk.BuildDataURL(actualMimeType, b64Data) + dataURL := BuildDataURL(actualMimeType, b64Data) - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( + ctxPrompt := UserPromptContext( PromptBlock{ Type: PromptBlockImage, ImageURL: dataURL, @@ -236,7 +223,7 @@ func (oc *AIClient) analyzeImageWithModel( Type: PromptBlockText, Text: prompt, }, - )} + ) resp, err := oc.provider.Generate(ctx, GenerateParams{ Model: modelIDForAPI, @@ -249,67 +236,6 @@ func (oc *AIClient) analyzeImageWithModel( return strings.TrimSpace(resp.Content), nil } - -func (oc *AIClient) analyzeAudioWithModel( - ctx context.Context, - modelID string, - audioURL string, - mimeType string, - encryptedFile *event.EncryptedFileInfo, - prompt string, -) (string, error) { - if strings.TrimSpace(modelID) == "" { - return "", errors.New("missing model for audio analysis") - } - if strings.TrimSpace(prompt) == "" { - prompt = defaultPromptByCapability[MediaCapabilityAudio] - } - - modelIDForAPI := oc.modelIDForAPI(modelID) - audioRef := mediaSourceLabel(audioURL, encryptedFile) - b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, audioURL, encryptedFile, 25, mimeType) - if err != nil { - return "", fmt.Errorf("failed to download audio %s for model %s: %w", audioRef, modelIDForAPI, err) - } - actualMimeType = strings.TrimSpace(actualMimeType) - if actualMimeType == "" { - actualMimeType = strings.TrimSpace(mimeType) - } - format := getAudioFormat(actualMimeType) - if format == "" { - format = "mp3" - } - - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( - PromptBlock{ - Type: PromptBlockAudio, - AudioB64: b64Data, - AudioFormat: format, - }, - PromptBlock{ - Type: PromptBlockText, - Text: prompt, - }, - )} - - params := GenerateParams{ - Model: modelIDForAPI, - Context: ctxPrompt, - MaxCompletionTokens: defaultImageUnderstandingLimit, - } - var resp *GenerateResponse - if provider, ok := oc.provider.(*OpenAIProvider); ok { - resp, err = provider.generateChatCompletions(ctx, params) - } else { - resp, err = oc.provider.Generate(ctx, params) - } - if err != nil { - return "", fmt.Errorf("audio analysis failed for model %s (audio %s): %w", modelIDForAPI, audioRef, err) - } - - return strings.TrimSpace(resp.Content), nil -} - func mediaSourceLabel(mediaURL string, encryptedFile *event.EncryptedFileInfo) string { source := strings.TrimSpace(mediaURL) if encryptedFile != nil && encryptedFile.URL != "" { diff --git a/bridges/ai/integration_host.go b/bridges/ai/integration_host.go index 90325d25..1f6c25aa 100644 --- a/bridges/ai/integration_host.go +++ b/bridges/ai/integration_host.go @@ -9,6 +9,7 @@ import ( "github.com/openai/openai-go/v3" "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -120,7 +121,7 @@ func (h *runtimeIntegrationHost) AgentModuleConfig(agentID string, module string // ---- Host methods: logger access ---- -func (h *runtimeIntegrationHost) RawLogger() any { +func (h *runtimeIntegrationHost) RawLogger() zerolog.Logger { if h == nil || h.client == nil { return zerolog.Logger{} } @@ -129,7 +130,7 @@ func (h *runtimeIntegrationHost) RawLogger() any { // ---- Host methods: portal management ---- -func (h *runtimeIntegrationHost) GetOrCreatePortal(ctx context.Context, portalID string, receiver string, displayName string, setupMeta func(meta any)) (portal any, roomID string, err error) { +func (h *runtimeIntegrationHost) GetOrCreatePortal(ctx context.Context, portalID string, receiver string, displayName string, setupMeta func(meta *PortalMetadata)) (portal *bridgev2.Portal, roomID string, err error) { if h == nil || h.client == nil || h.client.UserLogin == nil { return nil, "", fmt.Errorf("missing login") } @@ -155,122 +156,55 @@ func (h *runtimeIntegrationHost) GetOrCreatePortal(ctx context.Context, portalID return p, p.MXID.String(), nil } -func (h *runtimeIntegrationHost) SavePortal(ctx context.Context, portal any, reason string) error { +func (h *runtimeIntegrationHost) SavePortal(ctx context.Context, portal *bridgev2.Portal, reason string) error { if h == nil || h.client == nil { return nil } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return nil } - h.client.savePortalQuiet(ctx, p, reason) + h.client.savePortalQuiet(ctx, portal, reason) return nil } -func (h *runtimeIntegrationHost) PortalRoomID(portal any) string { - p, _ := portal.(*bridgev2.Portal) - if p == nil { +func (h *runtimeIntegrationHost) PortalRoomID(portal *bridgev2.Portal) string { + if portal == nil { return "" } - return p.MXID.String() + return portal.MXID.String() } -func (h *runtimeIntegrationHost) PortalKeyString(portal any) string { - p, _ := portal.(*bridgev2.Portal) - if p == nil { +func (h *runtimeIntegrationHost) PortalKeyString(portal *bridgev2.Portal) string { + if portal == nil { return "" } - return p.PortalKey.String() + return portal.PortalKey.String() } -// ---- Host methods: metadata access ---- - -func (h *runtimeIntegrationHost) GetModuleMeta(meta any, key string) any { - m, _ := meta.(*PortalMetadata) - if m == nil || m.ModuleMeta == nil { - return nil - } - return m.ModuleMeta[key] -} - -func (h *runtimeIntegrationHost) SetModuleMeta(meta any, key string, value any) { - m, _ := meta.(*PortalMetadata) - if m == nil { - return - } - m.SetModuleMeta(key, value) -} - -func (h *runtimeIntegrationHost) AgentIDFromMeta(meta any) string { - m, _ := meta.(*PortalMetadata) - return resolveAgentID(m) -} - -func (h *runtimeIntegrationHost) CompactionCount(meta any) int { - m, _ := meta.(*PortalMetadata) - if m == nil { - return 0 - } - return m.CompactionCount -} - -func (h *runtimeIntegrationHost) IsGroupChat(ctx context.Context, portal any) bool { +func (h *runtimeIntegrationHost) IsGroupChat(ctx context.Context, portal *bridgev2.Portal) bool { if h == nil || h.client == nil { return false } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return false } - return h.client.isGroupChat(ctx, p) -} - -func (h *runtimeIntegrationHost) IsInternalRoom(meta any) bool { - m, _ := meta.(*PortalMetadata) - if m == nil { - return false - } - return isModuleInternalRoom(m) -} - -func (h *runtimeIntegrationHost) PortalMeta(portal any) any { - p, _ := portal.(*bridgev2.Portal) - return portalMeta(p) -} - -func (h *runtimeIntegrationHost) CloneMeta(portal any) any { - p, _ := portal.(*bridgev2.Portal) - return clonePortalMetadata(portalMeta(p)) -} - -func (h *runtimeIntegrationHost) SetMetaField(meta any, key string, value any) { - m, _ := meta.(*PortalMetadata) - if m == nil { - return - } - switch key { - case "DisabledTools": - if v, ok := value.([]string); ok { - m.DisabledTools = v - } - } + return h.client.isGroupChat(ctx, portal) } // ---- Host methods: message helpers ---- -func (h *runtimeIntegrationHost) RecentMessages(ctx context.Context, portal any, count int) []integrationruntime.MessageSummary { +func (h *runtimeIntegrationHost) RecentMessages(ctx context.Context, portal *bridgev2.Portal, count int) []integrationruntime.MessageSummary { if h == nil || h.client == nil { return nil } - p, _ := portal.(*bridgev2.Portal) - if p == nil || count <= 0 || h.client.UserLogin == nil || h.client.UserLogin.Bridge == nil || h.client.UserLogin.Bridge.DB == nil { + if portal == nil || count <= 0 || h.client.UserLogin == nil || h.client.UserLogin.Bridge == nil || h.client.UserLogin.Bridge.DB == nil { return nil } maxMessages := count if maxMessages > 10 { maxMessages = 10 } - history, err := h.client.UserLogin.Bridge.DB.Message.GetLastNInPortal(h.client.backgroundContext(ctx), p.PortalKey, maxMessages) + history, err := h.client.UserLogin.Bridge.DB.Message.GetLastNInPortal(h.client.backgroundContext(ctx), portal.PortalKey, maxMessages) if err != nil || len(history) == 0 { return nil } @@ -293,20 +227,18 @@ func (h *runtimeIntegrationHost) RecentMessages(ctx context.Context, portal any, return out } -func (h *runtimeIntegrationHost) LastAssistantMessage(ctx context.Context, portal any) (id string, timestamp int64) { +func (h *runtimeIntegrationHost) LastAssistantMessage(ctx context.Context, portal *bridgev2.Portal) (id string, timestamp int64) { if h == nil || h.client == nil { return "", 0 } - p, _ := portal.(*bridgev2.Portal) - return h.client.lastAssistantMessageInfo(ctx, p) + return h.client.lastAssistantMessageInfo(ctx, portal) } -func (h *runtimeIntegrationHost) WaitForAssistantMessage(ctx context.Context, portal any, afterID string, afterTS int64) (*integrationruntime.AssistantMessageInfo, bool) { +func (h *runtimeIntegrationHost) WaitForAssistantMessage(ctx context.Context, portal *bridgev2.Portal, afterID string, afterTS int64) (*integrationruntime.AssistantMessageInfo, bool) { if h == nil || h.client == nil { return nil, false } - p, _ := portal.(*bridgev2.Portal) - msg, found := h.client.waitForNewAssistantMessage(ctx, p, afterID, afterTS) + msg, found := h.client.waitForNewAssistantMessage(ctx, portal, afterID, afterTS) if !found || msg == nil { return nil, false } @@ -331,7 +263,7 @@ func (h *runtimeIntegrationHost) RunHeartbeatOnce(ctx context.Context, reason st return h.client.scheduler.RunHeartbeatSweep(ctx, reason) } -func (h *runtimeIntegrationHost) ResolveHeartbeatSessionPortal(agentID string) (portal any, sessionKey string, err error) { +func (h *runtimeIntegrationHost) ResolveHeartbeatSessionPortal(agentID string) (portal *bridgev2.Portal, sessionKey string, err error) { if h == nil || h.client == nil { return nil, "", fmt.Errorf("missing client") } @@ -450,7 +382,7 @@ func (h *runtimeIntegrationHost) NormalizeThinkingLevel(raw string) (string, boo // ---- Host methods: model helpers ---- -func (h *runtimeIntegrationHost) EffectiveModel(meta any) string { +func (h *runtimeIntegrationHost) EffectiveModel(meta integrationruntime.Meta) string { if h == nil || h.client == nil { return "" } @@ -458,7 +390,7 @@ func (h *runtimeIntegrationHost) EffectiveModel(meta any) string { return h.client.effectiveModel(m) } -func (h *runtimeIntegrationHost) ContextWindow(meta any) int { +func (h *runtimeIntegrationHost) ContextWindow(meta integrationruntime.Meta) int { if h == nil || h.client == nil { return 0 } @@ -509,15 +441,14 @@ func (h *runtimeIntegrationHost) BackgroundContext(ctx context.Context) context. // ---- Host methods: chat completions ---- -func (h *runtimeIntegrationHost) NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams any) (*integrationruntime.CompletionResult, error) { +func (h *runtimeIntegrationHost) NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams []openai.ChatCompletionToolUnionParam) (*integrationruntime.CompletionResult, error) { if h == nil || h.client == nil { return nil, fmt.Errorf("missing client") } - params, _ := toolParams.([]openai.ChatCompletionToolUnionParam) req := openai.ChatCompletionNewParams{ Model: model, Messages: messages, - Tools: params, + Tools: toolParams, } resp, err := h.client.api.Chat.Completions.New(ctx, req) if err != nil { @@ -549,7 +480,7 @@ func (h *runtimeIntegrationHost) NewCompletion(ctx context.Context, model string // ---- Host methods: tool policy ---- -func (h *runtimeIntegrationHost) IsToolEnabled(meta any, toolName string) bool { +func (h *runtimeIntegrationHost) IsToolEnabled(meta integrationruntime.Meta, toolName string) bool { if h == nil || h.client == nil { return true } @@ -567,24 +498,23 @@ func (h *runtimeIntegrationHost) AllToolDefinitions() []integrationruntime.ToolD return out } -func (h *runtimeIntegrationHost) ExecuteToolInContext(ctx context.Context, portal any, meta any, name string, argsJSON string) (string, error) { +func (h *runtimeIntegrationHost) ExecuteToolInContext(ctx context.Context, portal *bridgev2.Portal, meta integrationruntime.Meta, name string, argsJSON string) (string, error) { if h == nil || h.client == nil { return "", fmt.Errorf("missing client") } - p, _ := portal.(*bridgev2.Portal) m, _ := meta.(*PortalMetadata) if m != nil && !h.client.isToolEnabled(m, name) { return "", fmt.Errorf("tool %s is disabled", name) } toolCtx := WithBridgeToolContext(ctx, &BridgeToolContext{ Client: h.client, - Portal: p, + Portal: portal, Meta: m, }) - return h.client.executeBuiltinTool(toolCtx, p, name, argsJSON) + return h.client.executeBuiltinTool(toolCtx, portal, name, argsJSON) } -func (h *runtimeIntegrationHost) ToolsToOpenAIParams(tools []integrationruntime.ToolDefinition) any { +func (h *runtimeIntegrationHost) ToolsToOpenAIParams(tools []integrationruntime.ToolDefinition) []openai.ChatCompletionToolUnionParam { if h == nil || h.client == nil { return nil } @@ -614,7 +544,7 @@ func (h *runtimeIntegrationHost) ReadTextFile(ctx context.Context, agentID strin return entry.Content, entry.Path, true, nil } -func (h *runtimeIntegrationHost) WriteTextFile(ctx context.Context, portal any, meta any, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) { +func (h *runtimeIntegrationHost) WriteTextFile(ctx context.Context, portal *bridgev2.Portal, meta integrationruntime.Meta, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) { if h == nil || h.client == nil { return "", fmt.Errorf("storage unavailable") } @@ -644,11 +574,10 @@ func (h *runtimeIntegrationHost) WriteTextFile(ctx context.Context, portal any, return "", e } if entry != nil { - p, _ := portal.(*bridgev2.Portal) m, _ := meta.(*PortalMetadata) toolCtx := WithBridgeToolContext(ctx, &BridgeToolContext{ Client: h.client, - Portal: p, + Portal: portal, Meta: m, }) notifyIntegrationFileChanged(toolCtx, entry.Path) @@ -766,7 +695,7 @@ func (h *runtimeIntegrationHost) SessionPortals(ctx context.Context, loginID str return out, nil } -func (h *runtimeIntegrationHost) LoginDB() any { +func (h *runtimeIntegrationHost) LoginDB() *dbutil.Database { if h == nil || h.client == nil { return nil } @@ -819,52 +748,49 @@ func (h *runtimeIntegrationHost) CronRun(ctx context.Context, jobID string) (boo // ---- Host methods: dispatch/lookup primitives ---- -func (h *runtimeIntegrationHost) ResolvePortalByRoomID(ctx context.Context, roomID string) any { +func (h *runtimeIntegrationHost) ResolvePortalByRoomID(ctx context.Context, roomID string) *bridgev2.Portal { if h == nil || h.client == nil || strings.TrimSpace(roomID) == "" { return nil } return h.client.portalByRoomID(ctx, portalRoomIDFromString(roomID)) } -func (h *runtimeIntegrationHost) ResolveDefaultPortal(ctx context.Context) any { +func (h *runtimeIntegrationHost) ResolveDefaultPortal(ctx context.Context) *bridgev2.Portal { if h == nil || h.client == nil { return nil } return h.client.defaultChatPortal() } -func (h *runtimeIntegrationHost) ResolveLastActivePortal(ctx context.Context, agentID string) any { +func (h *runtimeIntegrationHost) ResolveLastActivePortal(ctx context.Context, agentID string) *bridgev2.Portal { if h == nil || h.client == nil { return nil } return h.client.lastActivePortal(agentID) } -func (h *runtimeIntegrationHost) DispatchInternalMessage(ctx context.Context, portal any, meta any, message string, source string) error { +func (h *runtimeIntegrationHost) DispatchInternalMessage(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, message string, source string) error { if h == nil || h.client == nil { return fmt.Errorf("missing client") } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return fmt.Errorf("missing portal") } - m, _ := meta.(*PortalMetadata) - if m == nil { - m = &PortalMetadata{} + if meta == nil { + meta = &PortalMetadata{} } - _, _, err := h.client.dispatchInternalMessage(ctx, p, m, message, source, false) + _, _, err := h.client.dispatchInternalMessage(ctx, portal, meta, message, source, false) return err } -func (h *runtimeIntegrationHost) SendAssistantMessage(ctx context.Context, portal any, body string) error { +func (h *runtimeIntegrationHost) SendAssistantMessage(ctx context.Context, portal *bridgev2.Portal, body string) error { if h == nil || h.client == nil { return fmt.Errorf("missing client") } - p, _ := portal.(*bridgev2.Portal) - if p == nil { + if portal == nil { return fmt.Errorf("missing portal") } - return h.client.sendPlainAssistantMessage(ctx, p, body) + return h.client.sendPlainAssistantMessage(ctx, portal, body) } func (h *runtimeIntegrationHost) RequestNow(ctx context.Context, reason string) { @@ -887,7 +813,7 @@ func (h *runtimeIntegrationHost) ExecuteBuiltinTool(ctx context.Context, scope i if h == nil || h.client == nil { return "", fmt.Errorf("missing client") } - portal, _ := scope.Portal.(*bridgev2.Portal) + portal := scope.Portal meta, _ := scope.Meta.(*PortalMetadata) if meta != nil && !h.client.isToolEnabled(meta, name) { return "", fmt.Errorf("tool %s is disabled", name) @@ -904,7 +830,7 @@ func (h *runtimeIntegrationHost) ResolveWorkspaceDir() string { return resolvePromptWorkspaceDir() } -func (h *runtimeIntegrationHost) BridgeDB() any { +func (h *runtimeIntegrationHost) BridgeDB() *dbutil.Database { if h == nil || h.client == nil { return nil } diff --git a/bridges/ai/integrations.go b/bridges/ai/integrations.go index fadb3283..69fceef8 100644 --- a/bridges/ai/integrations.go +++ b/bridges/ai/integrations.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "github.com/beeper/agentremote" @@ -80,46 +79,6 @@ func (r *toolIntegrationRegistry) availability( return false, false, SourceGlobalDefault, "" } -type promptIntegrationRegistry struct { - items []integrationruntime.PromptIntegration -} - -func (r *promptIntegrationRegistry) register(integration integrationruntime.PromptIntegration) { - if integration == nil { - return - } - r.items = append(r.items, integration) -} - -func (r *promptIntegrationRegistry) additionalMessages( - ctx context.Context, - scope integrationruntime.PromptScope, -) []openai.ChatCompletionMessageParamUnion { - if r == nil { - return nil - } - var out []openai.ChatCompletionMessageParamUnion - for _, integration := range r.items { - out = append(out, integration.AdditionalSystemMessages(ctx, scope)...) - } - return out -} - -func (r *promptIntegrationRegistry) augmentPrompt( - ctx context.Context, - scope integrationruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - if r == nil { - return prompt - } - out := prompt - for _, integration := range r.items { - out = integration.AugmentPrompt(ctx, scope, out) - } - return out -} - type commandIntegrationRegistration struct { integration integrationruntime.CommandIntegration definition integrationruntime.CommandDefinition @@ -260,26 +219,15 @@ func settingSourceFromIntegration(source integrationruntime.SettingSource) Setti func (oc *AIClient) toolScope(portal *bridgev2.Portal, meta *PortalMetadata) integrationruntime.ToolScope { return integrationruntime.ToolScope{ - Client: oc, Portal: portal, Meta: meta, } } -func (oc *AIClient) promptScope(portal *bridgev2.Portal, meta *PortalMetadata) integrationruntime.PromptScope { - return integrationruntime.PromptScope{ - Client: oc, - Portal: portal, - Meta: meta, - } -} - -func (oc *AIClient) commandScope(portal *bridgev2.Portal, meta *PortalMetadata, evt any) integrationruntime.CommandScope { +func (oc *AIClient) commandScope(portal *bridgev2.Portal, meta *PortalMetadata) integrationruntime.CommandScope { return integrationruntime.CommandScope{ - Client: oc, Portal: portal, Meta: meta, - Event: evt, } } @@ -288,12 +236,11 @@ func (oc *AIClient) initIntegrations() { return } oc.toolRegistry = &toolIntegrationRegistry{} - oc.promptRegistry = &promptIntegrationRegistry{} oc.commandRegistry = newCommandIntegrationRegistry() oc.eventRegistry = &eventIntegrationRegistry{} oc.purgeRegistry = &purgeIntegrationRegistry{} oc.approvalRegistry = &toolApprovalIntegrationRegistry{} - oc.integrationModules = make(map[string]any) + oc.integrationModules = make(map[string]integrationruntime.ModuleHooks) oc.integrationOrder = nil host := newRuntimeIntegrationHost(oc) @@ -307,11 +254,8 @@ func (oc *AIClient) initIntegrations() { if toolIntegration, ok := module.(integrationruntime.ToolIntegration); ok { oc.toolRegistry.register(toolIntegration) } - if promptIntegration, ok := module.(integrationruntime.PromptIntegration); ok { - oc.promptRegistry.register(promptIntegration) - } if commandIntegration, ok := module.(integrationruntime.CommandIntegration); ok { - defs := commandIntegration.CommandDefinitions(context.Background(), oc.commandScope(nil, nil, nil)) + defs := commandIntegration.CommandDefinitions(context.Background(), oc.commandScope(nil, nil)) oc.commandRegistry.register(commandIntegration, defs) } if eventIntegration, ok := module.(integrationruntime.EventIntegration); ok { @@ -325,11 +269,9 @@ func (oc *AIClient) initIntegrations() { } } - // Register core integrations after modules so module tool/prompt implementations take precedence. + // Register core integrations after modules so module tool implementations take precedence. coreTools := &coreToolIntegration{client: oc} - corePrompts := &corePromptIntegration{client: oc} oc.toolRegistry.register(coreTools) - oc.promptRegistry.register(corePrompts) registerModuleCommands(oc.commandRegistry.definitions()) } @@ -341,7 +283,7 @@ func (oc *AIClient) integratedToolApprovalRequirement(toolName string, args map[ return oc.approvalRegistry.requirement(toolName, args) } -func (oc *AIClient) registerIntegrationModule(name string, module any) { +func (oc *AIClient) registerIntegrationModule(name string, module integrationruntime.ModuleHooks) { if oc == nil || module == nil { return } @@ -350,7 +292,7 @@ func (oc *AIClient) registerIntegrationModule(name string, module any) { return } if oc.integrationModules == nil { - oc.integrationModules = make(map[string]any) + oc.integrationModules = make(map[string]integrationruntime.ModuleHooks) } if _, exists := oc.integrationModules[key]; exists { return @@ -386,14 +328,14 @@ func (oc *AIClient) emitCompactionLifecycle( } } -func (oc *AIClient) integrationModule(name string) any { +func (oc *AIClient) integrationModule(name string) integrationruntime.ModuleHooks { if oc == nil || oc.integrationModules == nil { return nil } return oc.integrationModules[strings.ToLower(strings.TrimSpace(name))] } -func (oc *AIClient) eachIntegrationModule(fn func(name string, module any)) { +func (oc *AIClient) eachIntegrationModule(fn func(name string, module integrationruntime.ModuleHooks)) { if oc == nil || fn == nil || len(oc.integrationOrder) == 0 { return } @@ -410,7 +352,7 @@ func (oc *AIClient) startLifecycleIntegrations(ctx context.Context) { if oc == nil { return } - oc.eachIntegrationModule(func(name string, module any) { + oc.eachIntegrationModule(func(name string, module integrationruntime.ModuleHooks) { lifecycle, ok := module.(integrationruntime.LifecycleIntegration) if !ok { return @@ -470,7 +412,7 @@ func (oc *AIClient) stopLoginLifecycleIntegrations(bridgeID, loginID string) { if oc == nil || strings.TrimSpace(bridgeID) == "" || strings.TrimSpace(loginID) == "" { return } - oc.eachIntegrationModule(func(_ string, module any) { + oc.eachIntegrationModule(func(_ string, module integrationruntime.ModuleHooks) { loginLifecycle, ok := module.(integrationruntime.LoginLifecycleIntegration) if !ok { return @@ -523,40 +465,10 @@ func (oc *AIClient) executeIntegratedTool( }) } -func (oc *AIClient) additionalSystemMessages( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - if oc == nil { - return nil - } - if oc.promptRegistry == nil { - return oc.buildAdditionalSystemPromptsCore(ctx, portal, meta) - } - return oc.promptRegistry.additionalMessages(ctx, oc.promptScope(portal, meta)) -} - -func (oc *AIClient) augmentPromptWithIntegrations( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - if oc == nil { - return prompt - } - if oc.promptRegistry == nil { - return prompt - } - return oc.promptRegistry.augmentPrompt(ctx, oc.promptScope(portal, meta), prompt) -} - func (oc *AIClient) executeIntegratedCommand( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - evt any, name string, args []string, rawArgs string, @@ -569,7 +481,7 @@ func (oc *AIClient) executeIntegratedCommand( Name: name, Args: args, RawArgs: rawArgs, - Scope: oc.commandScope(portal, meta, evt), + Scope: oc.commandScope(portal, meta), Reply: reply, }) } @@ -585,7 +497,6 @@ func (oc *AIClient) emitIntegrationSessionMutation( return } oc.eventRegistry.sessionMutation(ctx, integrationruntime.SessionMutationEvent{ - Client: oc, Portal: portal, Meta: meta, SessionKey: portal.PortalKey.String(), @@ -599,7 +510,6 @@ func (oc *AIClient) emitIntegrationFileChanged(ctx context.Context, portal *brid return } oc.eventRegistry.fileChanged(ctx, integrationruntime.FileChangedEvent{ - Client: oc, Portal: portal, Meta: meta, Path: path, @@ -623,13 +533,12 @@ func notifyIntegrationFileChanged(ctx context.Context, path string) { btc.Client.emitIntegrationFileChanged(ctx, btc.Portal, meta, path) } -func (oc *AIClient) purgeLoginIntegrations(ctx context.Context, login any, bridgeID, loginID string) { +// purgeLoginIntegrations keeps the login argument for parity with the logout cleanup call site. +func (oc *AIClient) purgeLoginIntegrations(ctx context.Context, _ *bridgev2.UserLogin, bridgeID, loginID string) { if oc == nil || oc.purgeRegistry == nil { return } if err := oc.purgeRegistry.purge(ctx, integrationruntime.LoginScope{ - Client: oc, - Login: login, BridgeID: bridgeID, LoginID: loginID, }); err != nil { @@ -696,8 +605,7 @@ func (c *coreToolIntegration) ExecuteTool(ctx context.Context, call integrationr } args = parsedArgs } - portal, _ := call.Scope.Portal.(*bridgev2.Portal) - result, err := c.client.executeBuiltinToolDirect(ctx, portal, call.Name, args) + result, err := c.client.executeBuiltinToolDirect(ctx, call.Scope.Portal, call.Name, args) if err != nil { return true, "", err } @@ -711,29 +619,3 @@ func (c *coreToolIntegration) ToolAvailability( ) (bool, bool, integrationruntime.SettingSource, string) { return false, false, integrationruntime.SourceGlobalDefault, "" } - -type corePromptIntegration struct { - client *AIClient -} - -func (c *corePromptIntegration) Name() string { return "core" } - -func (c *corePromptIntegration) AdditionalSystemMessages( - ctx context.Context, - scope integrationruntime.PromptScope, -) []openai.ChatCompletionMessageParamUnion { - if c == nil || c.client == nil { - return nil - } - portal, _ := scope.Portal.(*bridgev2.Portal) - meta, _ := scope.Meta.(*PortalMetadata) - return c.client.buildAdditionalSystemPromptsCore(ctx, portal, meta) -} - -func (c *corePromptIntegration) AugmentPrompt( - _ context.Context, - _ integrationruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - return prompt -} diff --git a/bridges/ai/integrations_config.go b/bridges/ai/integrations_config.go index ad7298b7..c36f6549 100644 --- a/bridges/ai/integrations_config.go +++ b/bridges/ai/integrations_config.go @@ -2,10 +2,10 @@ package ai import ( _ "embed" + "fmt" "strings" "time" - "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" "github.com/beeper/agentremote/pkg/agents" @@ -22,7 +22,6 @@ var exampleNetworkConfig string // the `network:` block in the main bridge config. type Config struct { Beeper BeeperConfig `yaml:"beeper"` - Providers ProvidersConfig `yaml:"providers"` Models *ModelsConfig `yaml:"models"` Bridge BridgeConfig `yaml:"bridge"` Tools ToolProvidersConfig `yaml:"tools"` @@ -38,12 +37,6 @@ type Config struct { DefaultSystemPrompt string `yaml:"default_system_prompt"` ModelCacheDuration time.Duration `yaml:"model_cache_duration"` - // Context pruning configuration - Pruning *airuntime.PruningConfig `yaml:"pruning"` - - // Link preview configuration - LinkPreviews *LinkPreviewConfig `yaml:"link_previews"` - // Inbound message processing configuration Inbound *InboundConfig `yaml:"inbound"` @@ -117,6 +110,12 @@ type AgentDefaultsConfig struct { SkipBootstrap bool `yaml:"skip_bootstrap"` BootstrapMaxChars int `yaml:"bootstrap_max_chars"` TimeoutSeconds int `yaml:"timeout_seconds"` + Model *ModelSelectionConfig `yaml:"model"` + ImageModel *ModelSelectionConfig `yaml:"image_model"` + ImageGeneration *ModelSelectionConfig `yaml:"image_generation_model"` + PDFModel *ModelSelectionConfig `yaml:"pdf_model"` + PDFEngine string `yaml:"pdf_engine"` + Compaction *airuntime.PruningConfig `yaml:"compaction"` SoulEvil *agents.SoulEvilConfig `yaml:"soul_evil"` Heartbeat *HeartbeatConfig `yaml:"heartbeat"` UserTimezone string `yaml:"user_timezone"` @@ -228,11 +227,16 @@ type SessionConfig struct { // ToolProvidersConfig configures external tool providers like search and fetch. type ToolProvidersConfig struct { - Search *SearchConfig `yaml:"search"` - Fetch *FetchConfig `yaml:"fetch"` - Media *MediaToolsConfig `yaml:"media"` - MCP *MCPToolsConfig `yaml:"mcp"` - VFS *VFSToolsConfig `yaml:"vfs"` + Web *WebToolsConfig `yaml:"web"` + Links *LinkPreviewConfig `yaml:"links"` + Media *MediaToolsConfig `yaml:"media"` + MCP *MCPToolsConfig `yaml:"mcp"` + VFS *VFSToolsConfig `yaml:"vfs"` +} + +type WebToolsConfig struct { + Search *SearchConfig `yaml:"search"` + Fetch *FetchConfig `yaml:"fetch"` } // MCPToolsConfig configures generic MCP behavior. @@ -422,21 +426,6 @@ type BeeperConfig struct { Token string `yaml:"token"` // Beeper Matrix access token } -// ProviderConfig holds settings for a specific AI provider. -type ProviderConfig struct { - APIKey string `yaml:"api_key"` - BaseURL string `yaml:"base_url"` - DefaultModel string `yaml:"default_model"` - DefaultPDFEngine string `yaml:"default_pdf_engine"` // pdf-text, mistral-ocr (default), native -} - -// ProvidersConfig contains per-provider configuration. -type ProvidersConfig struct { - Beeper ProviderConfig `yaml:"beeper"` - OpenAI ProviderConfig `yaml:"openai"` - OpenRouter ProviderConfig `yaml:"openrouter"` -} - // ModelsConfig configures model catalog seeding. type ModelsConfig struct { Mode string `yaml:"mode"` // merge | replace @@ -445,7 +434,46 @@ type ModelsConfig struct { // ModelProviderConfig describes models for a specific provider. type ModelProviderConfig struct { - Models []ModelDefinitionConfig `yaml:"models"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + Headers map[string]string `yaml:"headers"` + Models []ModelDefinitionConfig `yaml:"models"` +} + +type ModelSelectionConfig struct { + Primary string `yaml:"primary"` + Fallbacks []string `yaml:"fallbacks"` +} + +func (cfg *ModelsConfig) UnmarshalYAML(unmarshal func(any) error) error { + type rawModelsConfig ModelsConfig + var raw rawModelsConfig + if err := unmarshal(&raw); err != nil { + return err + } + if len(raw.Providers) > 0 { + normalizedProviders := make(map[string]ModelProviderConfig, len(raw.Providers)) + for key, provider := range raw.Providers { + normalized := strings.ToLower(strings.TrimSpace(key)) + if normalized == "" { + return fmt.Errorf("models.providers contains an empty provider key") + } + if _, exists := normalizedProviders[normalized]; exists { + return fmt.Errorf("models.providers contains duplicate provider key after normalization: %q", key) + } + normalizedProviders[normalized] = provider + } + raw.Providers = normalizedProviders + } + *cfg = ModelsConfig(raw) + return nil +} + +func (cfg *ModelsConfig) Provider(name string) ModelProviderConfig { + if cfg == nil || len(cfg.Providers) == 0 { + return ModelProviderConfig{} + } + return cfg.Providers[strings.ToLower(strings.TrimSpace(name))] } // ModelDefinitionConfig defines a model entry for catalog seeding. @@ -460,204 +488,3 @@ type ModelDefinitionConfig struct { // BridgeConfig is an alias for the shared bridge config. type BridgeConfig = bridgeconfig.BridgeConfig - -func upgradeConfig(helper configupgrade.Helper) { - // Beeper credentials for auto-login - helper.Copy(configupgrade.Str, "beeper", "user_mxid") - helper.Copy(configupgrade.Str, "beeper", "base_url") - helper.Copy(configupgrade.Str, "beeper", "token") - - // Per-provider default models - helper.Copy(configupgrade.Str, "providers", "beeper", "default_model") - helper.Copy(configupgrade.Str, "providers", "beeper", "default_pdf_engine") - helper.Copy(configupgrade.Str, "providers", "openai", "api_key") - helper.Copy(configupgrade.Str, "providers", "openai", "base_url") - helper.Copy(configupgrade.Str, "providers", "openai", "default_model") - helper.Copy(configupgrade.Str, "providers", "openrouter", "api_key") - helper.Copy(configupgrade.Str, "providers", "openrouter", "base_url") - helper.Copy(configupgrade.Str, "providers", "openrouter", "default_model") - helper.Copy(configupgrade.Str, "providers", "openrouter", "default_pdf_engine") - - // Global settings - helper.Copy(configupgrade.Str, "default_system_prompt") - helper.Copy(configupgrade.Str, "model_cache_duration") - helper.Copy(configupgrade.Str, "memory", "citations") - helper.Copy(configupgrade.Bool, "memory", "inject_context") - - // Tool approvals - helper.Copy(configupgrade.Bool, "tool_approvals", "enabled") - helper.Copy(configupgrade.Int, "tool_approvals", "ttl_seconds") - helper.Copy(configupgrade.Bool, "tool_approvals", "require_for_mcp") - helper.Copy(configupgrade.List, "tool_approvals", "require_for_tools") - - // Bridge-specific configuration - helper.Copy(configupgrade.Str, "bridge", "command_prefix") - - // Context pruning configuration - helper.Copy(configupgrade.Str, "pruning", "mode") - helper.Copy(configupgrade.Str, "pruning", "ttl") - helper.Copy(configupgrade.Bool, "pruning", "enabled") - helper.Copy(configupgrade.Float, "pruning", "soft_trim_ratio") - helper.Copy(configupgrade.Float, "pruning", "hard_clear_ratio") - helper.Copy(configupgrade.Int, "pruning", "keep_last_assistants") - helper.Copy(configupgrade.Int, "pruning", "min_prunable_chars") - helper.Copy(configupgrade.Int, "pruning", "soft_trim_max_chars") - helper.Copy(configupgrade.Int, "pruning", "soft_trim_head_chars") - helper.Copy(configupgrade.Int, "pruning", "soft_trim_tail_chars") - helper.Copy(configupgrade.Bool, "pruning", "hard_clear_enabled") - helper.Copy(configupgrade.Str, "pruning", "hard_clear_placeholder") - - // Compaction configuration (LLM summarization) - helper.Copy(configupgrade.Bool, "pruning", "summarization_enabled") - helper.Copy(configupgrade.Str, "pruning", "summarization_model") - helper.Copy(configupgrade.Int, "pruning", "max_summary_tokens") - helper.Copy(configupgrade.Str, "pruning", "compaction_mode") - helper.Copy(configupgrade.Int, "pruning", "keep_recent_tokens") - helper.Copy(configupgrade.Float, "pruning", "max_history_share") - helper.Copy(configupgrade.Int, "pruning", "reserve_tokens") - helper.Copy(configupgrade.Int, "pruning", "reserve_tokens_floor") - helper.Copy(configupgrade.Str, "pruning", "custom_instructions") - helper.Copy(configupgrade.Str, "pruning", "identifier_policy") - helper.Copy(configupgrade.Str, "pruning", "identifier_instructions") - helper.Copy(configupgrade.Str, "pruning", "post_compaction_refresh_prompt") - helper.Copy(configupgrade.Bool, "pruning", "overflow_flush", "enabled") - helper.Copy(configupgrade.Int, "pruning", "overflow_flush", "soft_threshold_tokens") - helper.Copy(configupgrade.Str, "pruning", "overflow_flush", "prompt") - helper.Copy(configupgrade.Str, "pruning", "overflow_flush", "system_prompt") - - // Link preview configuration - helper.Copy(configupgrade.Bool, "link_previews", "enabled") - helper.Copy(configupgrade.Int, "link_previews", "max_urls_inbound") - helper.Copy(configupgrade.Int, "link_previews", "max_urls_outbound") - helper.Copy(configupgrade.Str, "link_previews", "fetch_timeout") - helper.Copy(configupgrade.Int, "link_previews", "max_content_chars") - helper.Copy(configupgrade.Int, "link_previews", "max_page_bytes") - helper.Copy(configupgrade.Int, "link_previews", "max_image_bytes") - helper.Copy(configupgrade.Str, "link_previews", "cache_ttl") - - // Inbound message processing configuration - helper.Copy(configupgrade.Str, "inbound", "dedupe_ttl") - helper.Copy(configupgrade.Int, "inbound", "dedupe_max_size") - helper.Copy(configupgrade.Int, "inbound", "default_debounce_ms") - - // Cron configuration - helper.Copy(configupgrade.Bool, "cron", "enabled") - - // Messages configuration - helper.Copy(configupgrade.Str, "messages", "ack_reaction") - helper.Copy(configupgrade.Str, "messages", "ack_reaction_scope") - helper.Copy(configupgrade.Bool, "messages", "remove_ack_after") - helper.Copy(configupgrade.Int, "messages", "group_chat", "history_limit") - helper.Copy(configupgrade.List, "messages", "group_chat", "mention_patterns") - helper.Copy(configupgrade.Str, "messages", "group_chat", "activation") - helper.Copy(configupgrade.Int, "messages", "direct_chat", "history_limit") - helper.Copy(configupgrade.Int, "messages", "inbound", "debounce_ms") - helper.Copy(configupgrade.Map, "messages", "inbound", "by_channel") - helper.Copy(configupgrade.List, "commands", "owner_allow_from") - helper.Copy(configupgrade.Str, "messages", "queue", "mode") - helper.Copy(configupgrade.Map, "messages", "queue", "by_channel") - helper.Copy(configupgrade.Int, "messages", "queue", "debounce_ms") - helper.Copy(configupgrade.Map, "messages", "queue", "debounce_ms_by_channel") - helper.Copy(configupgrade.Int, "messages", "queue", "cap") - helper.Copy(configupgrade.Str, "messages", "queue", "drop") - - // Session configuration - helper.Copy(configupgrade.Str, "session", "scope") - helper.Copy(configupgrade.Str, "session", "main_key") - - // Agents heartbeat configuration - helper.Copy(configupgrade.Int, "agents", "defaults", "timeout_seconds") - helper.Copy(configupgrade.Str, "agents", "defaults", "user_timezone") - helper.Copy(configupgrade.Str, "agents", "defaults", "envelope_timezone") - helper.Copy(configupgrade.Str, "agents", "defaults", "envelope_timestamp") - helper.Copy(configupgrade.Str, "agents", "defaults", "envelope_elapsed") - helper.Copy(configupgrade.Str, "agents", "defaults", "typing_mode") - helper.Copy(configupgrade.Int, "agents", "defaults", "typing_interval_seconds") - helper.Copy(configupgrade.Map, "agents", "defaults", "subagents") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "every") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "prompt") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "model") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "session") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "target") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "to") - helper.Copy(configupgrade.Int, "agents", "defaults", "heartbeat", "ack_max_chars") - helper.Copy(configupgrade.Bool, "agents", "defaults", "heartbeat", "include_reasoning") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "active_hours", "start") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "active_hours", "end") - helper.Copy(configupgrade.Str, "agents", "defaults", "heartbeat", "active_hours", "timezone") - helper.Copy(configupgrade.List, "agents", "list") - - // Channels heartbeat visibility - helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "show_ok") - helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "show_alerts") - helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "use_indicator") - helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "show_ok") - helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "show_alerts") - helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "use_indicator") - helper.Copy(configupgrade.Str, "channels", "matrix", "reply_to_mode") - helper.Copy(configupgrade.Str, "channels", "matrix", "thread_replies") - - // Tools (search + fetch) - helper.Copy(configupgrade.Str, "tools", "search", "provider") - helper.Copy(configupgrade.List, "tools", "search", "fallbacks") - helper.Copy(configupgrade.Bool, "tools", "search", "exa", "enabled") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "base_url") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "api_key") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "type") - helper.Copy(configupgrade.Str, "tools", "search", "exa", "category") - helper.Copy(configupgrade.Int, "tools", "search", "exa", "num_results") - helper.Copy(configupgrade.Bool, "tools", "search", "exa", "include_text") - helper.Copy(configupgrade.Int, "tools", "search", "exa", "text_max_chars") - helper.Copy(configupgrade.Bool, "tools", "search", "exa", "highlights") - helper.Copy(configupgrade.Str, "tools", "fetch", "provider") - helper.Copy(configupgrade.List, "tools", "fetch", "fallbacks") - helper.Copy(configupgrade.Bool, "tools", "fetch", "exa", "enabled") - helper.Copy(configupgrade.Str, "tools", "fetch", "exa", "base_url") - helper.Copy(configupgrade.Str, "tools", "fetch", "exa", "api_key") - helper.Copy(configupgrade.Bool, "tools", "fetch", "exa", "include_text") - helper.Copy(configupgrade.Int, "tools", "fetch", "exa", "text_max_chars") - helper.Copy(configupgrade.Bool, "tools", "fetch", "direct", "enabled") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "timeout_seconds") - helper.Copy(configupgrade.Str, "tools", "fetch", "direct", "user_agent") - helper.Copy(configupgrade.Bool, "tools", "fetch", "direct", "readability") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "max_chars") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "max_redirects") - helper.Copy(configupgrade.Int, "tools", "fetch", "direct", "cache_ttl_seconds") - helper.Copy(configupgrade.Bool, "tools", "mcp", "enable_stdio") - helper.Copy(configupgrade.Int, "tools", "media", "image", "max_bytes") - helper.Copy(configupgrade.Int, "tools", "media", "image", "max_chars") - helper.Copy(configupgrade.Int, "tools", "media", "image", "timeout_seconds") - helper.Copy(configupgrade.Int, "tools", "media", "audio", "max_bytes") - helper.Copy(configupgrade.Int, "tools", "media", "audio", "timeout_seconds") - helper.Copy(configupgrade.Int, "tools", "media", "video", "max_bytes") - helper.Copy(configupgrade.Int, "tools", "media", "video", "timeout_seconds") - - // Memory search configuration - helper.Copy(configupgrade.Bool, "memory_search", "enabled") - helper.Copy(configupgrade.List, "memory_search", "sources") - helper.Copy(configupgrade.List, "memory_search", "extra_paths") - helper.Copy(configupgrade.Str, "memory_search", "store", "driver") - helper.Copy(configupgrade.Str, "memory_search", "store", "path") - helper.Copy(configupgrade.Int, "memory_search", "chunking", "tokens") - helper.Copy(configupgrade.Int, "memory_search", "chunking", "overlap") - helper.Copy(configupgrade.Bool, "memory_search", "sync", "on_session_start") - helper.Copy(configupgrade.Bool, "memory_search", "sync", "on_search") - helper.Copy(configupgrade.Bool, "memory_search", "sync", "watch") - helper.Copy(configupgrade.Int, "memory_search", "sync", "watch_debounce_ms") - helper.Copy(configupgrade.Int, "memory_search", "sync", "interval_minutes") - helper.Copy(configupgrade.Int, "memory_search", "sync", "sessions", "delta_bytes") - helper.Copy(configupgrade.Int, "memory_search", "sync", "sessions", "delta_messages") - helper.Copy(configupgrade.Int, "memory_search", "query", "max_results") - helper.Copy(configupgrade.Float, "memory_search", "query", "min_score") - helper.Copy(configupgrade.Int, "memory_search", "query", "hybrid", "candidate_multiplier") - helper.Copy(configupgrade.Bool, "memory_search", "cache", "enabled") - helper.Copy(configupgrade.Int, "memory_search", "cache", "max_entries") - helper.Copy(configupgrade.Bool, "memory_search", "experimental", "session_memory") - - // Tool policy - helper.Copy(configupgrade.Str, "tool_policy", "profile") - helper.Copy(configupgrade.List, "tool_policy", "allow") - helper.Copy(configupgrade.List, "tool_policy", "also_allow") - helper.Copy(configupgrade.List, "tool_policy", "deny") - helper.Copy(configupgrade.Map, "tool_policy", "by_provider") -} diff --git a/bridges/ai/integrations_config_test.go b/bridges/ai/integrations_config_test.go new file mode 100644 index 00000000..81147fc4 --- /dev/null +++ b/bridges/ai/integrations_config_test.go @@ -0,0 +1,43 @@ +package ai + +import ( + "strings" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestModelsConfigProviderMatchesNormalizedKeys(t *testing.T) { + var cfg ModelsConfig + if err := yaml.Unmarshal([]byte(` +mode: merge +providers: + " OpenAI ": + api_key: tok +`), &cfg); err != nil { + t.Fatalf("unmarshal config: %v", err) + } + + got := cfg.Provider("openai") + if got.APIKey != "tok" { + t.Fatalf("expected normalized provider lookup to match, got %#v", got) + } +} + +func TestModelsConfigUnmarshalRejectsNormalizedKeyCollisions(t *testing.T) { + var cfg ModelsConfig + err := yaml.Unmarshal([]byte(` +mode: merge +providers: + OpenAI: + api_key: tok-1 + " openai ": + api_key: tok-2 +`), &cfg) + if err == nil { + t.Fatal("expected duplicate normalized provider keys to fail") + } + if !strings.Contains(err.Error(), "duplicate provider key") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/bridges/ai/integrations_example-config.yaml b/bridges/ai/integrations_example-config.yaml index cab66e41..3bed8a64 100644 --- a/bridges/ai/integrations_example-config.yaml +++ b/bridges/ai/integrations_example-config.yaml @@ -1,157 +1,101 @@ # Connector-specific configuration lives under the `network:` section of the # main config file. -# Beeper Cloud credentials for automatic login (optional). -# If user_mxid, base_url, and token are set, users don't need to manually log in. beeper: - user_mxid: "" # Owning Matrix user for the built-in Beeper Cloud login. - base_url: "" # Optional. If empty, login uses selected Beeper domain. - token: "" # Beeper Matrix access token + user_mxid: "" + base_url: "" + token: "" -# Per-provider default models and settings. -# These are used when a room doesn't have a specific model configured. -providers: - beeper: - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - openai: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://api.openai.com/v1 - base_url: "https://api.openai.com/v1" - default_model: "openai/gpt-5.4" - openrouter: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://openrouter.ai/api/v1 - base_url: "https://openrouter.ai/api/v1" - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - -# Optional model catalog seeding. -# models: -# mode: "merge" # merge | replace -# providers: -# openai: -# models: -# - id: "gpt-5.2" -# name: "GPT-5.2" -# reasoning: true -# input: ["text", "image"] -# context_window: 128000 -# max_tokens: 8192 +models: + providers: + openai: + api_key: "" + base_url: "https://api.openai.com/v1" + models: [] + openrouter: + api_key: "" + base_url: "https://openrouter.ai/api/v1" + models: [] + magic_proxy: + api_key: "" + base_url: "" + models: [] -# Global settings default_system_prompt: | You are a helpful, concise assistant. Ask clarifying questions when needed. Follow the user's intent and be accurate. model_cache_duration: 6h -# Optional message rendering settings. messages: - # History defaults for prompt construction. - # Set 0 to disable. direct_chat: history_limit: 20 group_chat: history_limit: 50 - # Queue behavior while the agent is busy. queue: - # Modes: collect, followup, steer, steer-backlog, interrupt mode: "collect" - # Debounce time before draining queued messages (ms). debounce_ms: 1000 - # Maximum queued messages before drop policy applies. cap: 20 - # Drop policy when cap is exceeded: summarize, old, new drop: "summarize" -# Command authorization settings. commands: - # Optional allowlist for owner-only tools/commands (Matrix IDs, or "matrix:@user:server"). owner_allow_from: [] -# Tool approval gating. tool_approvals: enabled: true ttl_seconds: 600 require_for_mcp: true - # List of builtin tool names that require approval (subject to per-tool action allowlists). - # Note: `message` approvals apply to Desktop API routing too (e.g. action=send/reply/edit with desktop chat hints), - # while Desktop read-only actions like desktop-search-* do not require approval. require_for_tools: ["message", "cron", "gravatar_set", "create_agent", "fork_agent", "edit_agent", "delete_agent", "modify_room", "sessions_send", "sessions_spawn", "run_internal_command"] - # Fallback when approval times out: "deny" (default) | "allow". - # Set to "allow" for cron/automated contexts where no human can respond. -# Optional per-channel overrides. channels: matrix: - # Matrix reply/thread behavior. reply_to_mode: "first" -# Session configuration. session: - # Scope for session state: per-sender (default) or global. scope: "per-sender" - # Main session key alias (default: "main"). main_key: "main" -# External tool providers (search + fetch). Proxy is optional. tools: - search: - provider: "openrouter" - fallbacks: ["exa", "brave", "perplexity"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - type: "auto" - num_results: 5 - include_text: false - text_max_chars: 500 - highlights: true # enabled by default; provides description snippets for source cards - brave: - api_key: "" - base_url: "https://api.search.brave.com/res/v1/web/search" - perplexity: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "perplexity/sonar-pro" - openrouter: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "openai/gpt-5.4" - fetch: - provider: "exa" - fallbacks: ["direct"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - include_text: true - text_max_chars: 5000 - direct: - enabled: true - timeout_seconds: 30 - max_chars: 50000 - max_redirects: 3 - - # Generic MCP behavior. + web: + search: + provider: "exa" + fallbacks: [] + exa: + api_key: "" + base_url: "https://api.exa.ai" + type: "auto" + num_results: 5 + include_text: false + text_max_chars: 500 + highlights: true + fetch: + provider: "direct" + fallbacks: ["exa"] + exa: + api_key: "" + base_url: "https://api.exa.ai" + include_text: true + text_max_chars: 5000 + direct: + enabled: true + timeout_seconds: 30 + max_chars: 50000 + max_redirects: 3 + links: + enabled: true + max_urls_inbound: 3 + max_urls_outbound: 5 + fetch_timeout: 10s + max_content_chars: 500 + max_page_bytes: 10485760 + max_image_bytes: 5242880 + cache_ttl: 1h mcp: - # Disabled by default for safety. Enable explicitly to allow local stdio MCP servers. enable_stdio: false - - # Virtual filesystem tools. vfs: apply_patch: enabled: false allow_models: [] - - # Media understanding/transcription. - # Supports provider/CLI entries and per-capability defaults. media: concurrency: 2 image: @@ -167,7 +111,6 @@ tools: enabled: true prompt: "Transcribe the audio." language: "" - # CLI transcription auto-detection (whisper/whisper.cpp) is not implemented yet. max_bytes: 20971520 timeout_seconds: 60 models: @@ -181,184 +124,47 @@ tools: models: - provider: "openrouter" model: "google/gemini-3-flash-preview" - chunking: - tokens: 400 - overlap: 80 - sync: - on_session_start: true - on_search: true - watch: true - watch_debounce_ms: 1500 - interval_minutes: 0 - sessions: - delta_bytes: 100000 - delta_messages: 50 - query: - max_results: 6 - min_score: 0.35 - hybrid: - candidate_multiplier: 4 - cache: - enabled: true - max_entries: -1 # Unlimited when cache.enabled is true; experimental.session_memory does not change cache size semantics. - experimental: - session_memory: false - -# Recall configuration. -# recall: -# citations: "auto" # auto | on | off -# inject_context: false # default false. when true, injects MEMORY.md snippets as extra system context. - - # Tool policy. Controls allow/deny lists and profiles. - # tool_policy: - # profile: "full" - # # group:openclaw is the strict OpenClaw native tool set. - # # group:ai-bridge includes ai-bridge-only extras (beeper_docs, gravatar_*, tts, image_generate, calculator, etc). - # allow: ["group:openclaw", "group:ai-bridge"] - # deny: [] - # subagents: - # tools: - # deny: ["sessions_list", "sessions_history", "sessions_send"] - - # Agent defaults. - # agents: - # defaults: - # subagents: - # model: "anthropic/claude-sonnet-4.5" - # allow_agents: ["*"] - # skip_bootstrap: false - # bootstrap_max_chars: 20000 - # typing_mode: "instant" # never|instant|thinking|message (message ignores NO_REPLY) - # typing_interval_seconds: 6 # refresh cadence, not start time (heartbeats never show typing) - # soul_evil: - # file: "SOUL_EVIL.md" - # chance: 0.1 - # purge: - # at: "21:00" - # duration: "15m" - -# Context pruning configuration. -# Reduces token usage by intelligently truncating old tool results. -pruning: - # Pruning mode: off | cache-ttl - # cache-ttl is the default pruning mode. - mode: "cache-ttl" - - # Refresh interval for cache-ttl mode. - ttl: "1h" - - # Enable proactive context pruning - enabled: true - - # Ratio of context window usage that triggers soft trimming (0.0-1.0) - # At 30% usage, large tool results start getting truncated - soft_trim_ratio: 0.3 - - # Ratio of context window usage that triggers hard clearing (0.0-1.0) - # At 50% usage, old tool results are replaced with placeholder - hard_clear_ratio: 0.5 - - # Number of recent assistant messages to protect from pruning - keep_last_assistants: 3 - # Minimum total chars in prunable tool results before hard clear kicks in - min_prunable_chars: 50000 - - # Tool results larger than this are candidates for soft trimming - soft_trim_max_chars: 4000 - - # When soft trimming, keep this many chars from the start - soft_trim_head_chars: 1500 - - # When soft trimming, keep this many chars from the end - soft_trim_tail_chars: 1500 - - # Enable/disable hard clear phase - hard_clear_enabled: true - - # Placeholder text for hard-cleared tool results - hard_clear_placeholder: "[Old tool result content cleared]" - - # Tool patterns to allow/deny pruning (supports wildcards: list_*, *_search) - # Empty means all tools are prunable unless denied - # tools_allow: [] - # tools_deny: [] - - # --- LLM-based summarization (compaction) --- - # When enabled, uses an LLM to generate intelligent summaries of compacted - # content instead of just using placeholder text. This preserves context better. - - # Enable LLM summarization (default: true when pruning is enabled) - summarization_enabled: true - - # Model to use for generating summaries (default: fast model) - summarization_model: "openai/gpt-5.4" - - # Maximum tokens for generated summaries - max_summary_tokens: 500 - - # Compaction mode: - # - default: balanced reduction - # - safeguard: preserves recent context more aggressively - compaction_mode: "safeguard" - - # Minimum recent token budget preserved during safeguard compaction - keep_recent_tokens: 20000 - - # Maximum ratio of context that history can consume (0.0-1.0) - # When exceeded, oldest messages are summarized to fit budget - max_history_share: 0.5 - - # Token budget reserved for compaction output - reserve_tokens: 20000 - # Floor applied to reserve_tokens to avoid aggressive overfill - reserve_tokens_floor: 20000 - - # Optional post-compaction system context injected before retry - post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." - - # Additional instructions for the summarization model - # custom_instructions: "Focus on preserving code decisions and TODOs" - - # Identifier preservation policy for summaries: - # - strict (default): preserve opaque identifiers exactly - # - off: no special identifier-preservation instruction - # - custom: use identifier_instructions below - identifier_policy: "strict" - # identifier_instructions: "Keep ticket IDs, hashes, and hostnames unchanged." - - # Optional pre-compaction overflow flush turn. - # Enabled by default. Disable explicitly if you want no pre-flush. - overflow_flush: - enabled: true - soft_threshold_tokens: 4000 - prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." - system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - -# Link preview configuration. -# Automatically fetches metadata for URLs in messages to provide context to the AI -# and generate rich previews in outgoing AI responses. -link_previews: - # Enable link preview functionality (default: true) - enabled: true - - # Maximum number of URLs to fetch from user messages for AI context (default: 3) - max_urls_inbound: 3 - - # Maximum number of URLs to preview in AI responses (default: 5) - max_urls_outbound: 5 - - # Timeout for fetching each URL (default: 10s) - fetch_timeout: 10s - - # Maximum characters from description to include in context (default: 500) - max_content_chars: 500 - - # Maximum page size to download in bytes (default: 10MB) - max_page_bytes: 10485760 - - # Maximum image size to download in bytes (default: 5MB) - max_image_bytes: 5242880 - - # How long to cache URL previews (default: 1h) - cache_ttl: 1h +agents: + defaults: + model: + primary: "" + fallbacks: [] + image_model: + primary: "" + fallbacks: [] + image_generation_model: + primary: "" + fallbacks: [] + pdf_model: + primary: "" + fallbacks: [] + pdf_engine: "mistral-ocr" + compaction: + mode: "cache-ttl" + ttl: "1h" + enabled: true + soft_trim_ratio: 0.3 + hard_clear_ratio: 0.5 + keep_last_assistants: 3 + min_prunable_chars: 50000 + soft_trim_max_chars: 4000 + soft_trim_head_chars: 1500 + soft_trim_tail_chars: 1500 + hard_clear_enabled: true + hard_clear_placeholder: "[Old tool result content cleared]" + summarization_enabled: true + summarization_model: "openai/gpt-5.2" + max_summary_tokens: 500 + compaction_mode: "safeguard" + keep_recent_tokens: 20000 + max_history_share: 0.5 + reserve_tokens: 20000 + reserve_tokens_floor: 20000 + identifier_policy: "strict" + post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." + overflow_flush: + enabled: true + soft_threshold_tokens: 4000 + prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." + system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." diff --git a/bridges/ai/integrations_example_config_test.go b/bridges/ai/integrations_example_config_test.go new file mode 100644 index 00000000..a5ca1a1b --- /dev/null +++ b/bridges/ai/integrations_example_config_test.go @@ -0,0 +1,68 @@ +package ai + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestExampleConfigFilesExcludeLegacyConfigSections(t *testing.T) { + legacyKeys := map[string]struct{}{ + "chunking": {}, + "sync": {}, + "query": {}, + "cache": {}, + "experimental": {}, + "pruning": {}, + "recall": {}, + } + + t.Run("bridge example", func(t *testing.T) { + rel := "integrations_example-config.yaml" + data, err := os.ReadFile(rel) + if err != nil { + t.Fatalf("read %s: %v", rel, err) + } + + var doc map[string]any + if err := yaml.Unmarshal(data, &doc); err != nil { + t.Fatalf("unmarshal %s: %v", rel, err) + } + + if path := findLegacyConfigKey(doc, legacyKeys, nil); path != "" { + t.Fatalf("found legacy config key %q in %s", path, rel) + } + }) + + t.Run("legacy connector example removed", func(t *testing.T) { + rel := filepath.Join("..", "..", "pkg", "connector", "integrations_example-config.yaml") + if _, err := os.Stat(rel); !os.IsNotExist(err) { + t.Fatalf("expected stale generic example %s to be removed, got err=%v", rel, err) + } + }) +} + +func findLegacyConfigKey(node any, legacyKeys map[string]struct{}, path []string) string { + switch value := node.(type) { + case map[string]any: + for key, child := range value { + if _, ok := legacyKeys[key]; ok { + return strings.Join(append(path, key), ".") + } + if found := findLegacyConfigKey(child, legacyKeys, append(path, key)); found != "" { + return found + } + } + case []any: + for idx, child := range value { + if found := findLegacyConfigKey(child, legacyKeys, append(path, fmt.Sprintf("[%d]", idx))); found != "" { + return found + } + } + } + return "" +} diff --git a/bridges/ai/integrations_test.go b/bridges/ai/integrations_test.go index 517def4f..5be5fea4 100644 --- a/bridges/ai/integrations_test.go +++ b/bridges/ai/integrations_test.go @@ -3,11 +3,8 @@ package ai import ( "context" "reflect" - "slices" "testing" - "github.com/openai/openai-go/v3" - integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" ) @@ -30,28 +27,6 @@ func (f fakeToolIntegration) ToolAvailability(_ context.Context, _ integrationru return false, false, integrationruntime.SourceGlobalDefault, "" } -type fakePromptIntegration struct { - name string - tag string -} - -func (f fakePromptIntegration) Name() string { return f.name } - -func (f fakePromptIntegration) AdditionalSystemMessages(_ context.Context, _ integrationruntime.PromptScope) []openai.ChatCompletionMessageParamUnion { - return []openai.ChatCompletionMessageParamUnion{openai.SystemMessage("sys:" + f.tag)} -} - -func (f fakePromptIntegration) AugmentPrompt( - _ context.Context, - _ integrationruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, -) []openai.ChatCompletionMessageParamUnion { - out := make([]openai.ChatCompletionMessageParamUnion, 0, len(prompt)+1) - out = append(out, prompt...) - out = append(out, openai.UserMessage("aug:"+f.tag)) - return out -} - func TestToolIntegrationRegistryDefinitionsDeterministic(t *testing.T) { reg := &toolIntegrationRegistry{} reg.register(fakeToolIntegration{name: "one", defs: []integrationruntime.ToolDefinition{{Name: "a"}, {Name: "b"}}}) @@ -68,41 +43,6 @@ func TestToolIntegrationRegistryDefinitionsDeterministic(t *testing.T) { } } -func TestPromptIntegrationRegistryOrder(t *testing.T) { - reg := &promptIntegrationRegistry{} - reg.register(fakePromptIntegration{name: "one", tag: "1"}) - reg.register(fakePromptIntegration{name: "two", tag: "2"}) - - sys := reg.additionalMessages(context.Background(), integrationruntime.PromptScope{}) - if len(sys) != 2 { - t.Fatalf("expected 2 system messages, got %d", len(sys)) - } - - base := []openai.ChatCompletionMessageParamUnion{openai.UserMessage("base")} - out := reg.augmentPrompt(context.Background(), integrationruntime.PromptScope{}, base) - if len(out) != 3 { - t.Fatalf("expected augmented prompt len=3, got %d", len(out)) - } -} - -func TestPromptIntegrationRegistryAugmentPromptIdempotent(t *testing.T) { - reg := &promptIntegrationRegistry{} - reg.register(fakePromptIntegration{name: "one", tag: "1"}) - reg.register(fakePromptIntegration{name: "two", tag: "2"}) - - base := []openai.ChatCompletionMessageParamUnion{openai.UserMessage("base")} - baseCopy := slices.Clone(base) - - outA := reg.augmentPrompt(context.Background(), integrationruntime.PromptScope{}, base) - outB := reg.augmentPrompt(context.Background(), integrationruntime.PromptScope{}, base) - if !reflect.DeepEqual(outA, outB) { - t.Fatalf("augmentPrompt should be deterministic/idempotent; got outA=%v outB=%v", outA, outB) - } - if !reflect.DeepEqual(base, baseCopy) { - t.Fatalf("augmentPrompt mutated input prompt; got=%v want=%v", base, baseCopy) - } -} - type fakeLifecycleIntegration struct { startCount int stopCount int @@ -110,6 +50,8 @@ type fakeLifecycleIntegration struct { name string } +func (f *fakeLifecycleIntegration) Name() string { return f.name } + func (f *fakeLifecycleIntegration) Start(_ context.Context) error { f.startCount++ return nil diff --git a/bridges/ai/internal_dispatch.go b/bridges/ai/internal_dispatch.go index 21fddb9e..d4cfd657 100644 --- a/bridges/ai/internal_dispatch.go +++ b/bridges/ai/internal_dispatch.go @@ -44,7 +44,7 @@ func (oc *AIClient) dispatchInternalMessage( inboundCtx := oc.resolvePromptInboundContext(ctx, portal, trimmed, eventID) promptCtx := withInboundContext(ctx, inboundCtx) - promptContext, err := oc.buildContextWithLinkContext(promptCtx, portal, meta, trimmed, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(promptCtx, portal, meta, trimmed, nil, eventID) if err != nil { return eventID, false, err } @@ -60,7 +60,6 @@ func (oc *AIClient) dispatchInternalMessage( Timestamp: time.Now(), } setCanonicalTurnDataFromPromptMessages(userMessage.Metadata.(*MessageMetadata), promptTail(promptContext, 1)) - ensureCanonicalUserMessage(userMessage) if _, err := oc.UserLogin.Bridge.GetGhostByID(ctx, userMessage.SenderID); err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure user ghost before saving internal message") } diff --git a/bridges/ai/login.go b/bridges/ai/login.go index 00532fbd..4140e491 100644 --- a/bridges/ai/login.go +++ b/bridges/ai/login.go @@ -220,10 +220,17 @@ func (ol *OpenAILogin) finishLogin(ctx context.Context, provider, apiKey, baseUR meta = &UserLoginMetadata{} } meta.Provider = provider - meta.APIKey = apiKey - meta.BaseURL = baseURL + creds := &LoginCredentials{ + APIKey: apiKey, + BaseURL: baseURL, + } if serviceTokens != nil && !serviceTokensEmpty(serviceTokens) { - meta.ServiceTokens = mergeServiceTokens(meta.ServiceTokens, serviceTokens) + creds.ServiceTokens = cloneServiceTokens(serviceTokens) + } + if loginCredentialsEmpty(creds) { + meta.Credentials = nil + } else { + meta.Credentials = creds } if err := ol.validateLoginMetadata(ctx, loginID, meta); err != nil { return nil, err @@ -331,46 +338,12 @@ func (ol *OpenAILogin) validateLoginMetadata(ctx context.Context, loginID networ return nil } -func serviceTokensEmpty(tokens *ServiceTokens) bool { - if tokens == nil { - return true - } - if len(tokens.DesktopAPIInstances) > 0 { - for _, instance := range tokens.DesktopAPIInstances { - if strings.TrimSpace(instance.Token) != "" || strings.TrimSpace(instance.BaseURL) != "" { - return false - } - } - } - if len(tokens.MCPServers) > 0 { - for _, server := range tokens.MCPServers { - if strings.TrimSpace(server.Transport) != "" || - strings.TrimSpace(server.Endpoint) != "" || - strings.TrimSpace(server.Command) != "" || - len(server.Args) > 0 || - strings.TrimSpace(server.Token) != "" || - strings.TrimSpace(server.AuthURL) != "" || - strings.TrimSpace(server.AuthType) != "" || - strings.TrimSpace(server.Kind) != "" || - server.Connected { - return false - } - } - } - return strings.TrimSpace(tokens.OpenAI) == "" && - strings.TrimSpace(tokens.OpenRouter) == "" && - strings.TrimSpace(tokens.Exa) == "" && - strings.TrimSpace(tokens.Brave) == "" && - strings.TrimSpace(tokens.Perplexity) == "" && - strings.TrimSpace(tokens.DesktopAPI) == "" -} - func (ol *OpenAILogin) resolveCustomLogin(input map[string]string) (string, string, *ServiceTokens, error) { if input == nil { input = map[string]string{} } - openrouterCfg := strings.TrimSpace(ol.Connector.Config.Providers.OpenRouter.APIKey) - openaiCfg := strings.TrimSpace(ol.Connector.Config.Providers.OpenAI.APIKey) + openrouterCfg := strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenRouter).APIKey) + openaiCfg := strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenAI).APIKey) openrouterInput := "" openaiInput := "" @@ -478,18 +451,22 @@ func parseMagicProxyLink(raw string) (string, string, error) { } func (ol *OpenAILogin) configHasOpenRouterKey() bool { - return strings.TrimSpace(ol.Connector.Config.Providers.OpenRouter.APIKey) != "" + return strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenRouter).APIKey) != "" } func (ol *OpenAILogin) configHasOpenAIKey() bool { - return strings.TrimSpace(ol.Connector.Config.Providers.OpenAI.APIKey) != "" + return strings.TrimSpace(ol.Connector.modelProviderConfig(ProviderOpenAI).APIKey) != "" } func (ol *OpenAILogin) configHasExaKey() bool { - if ol.Connector.Config.Tools.Search != nil && strings.TrimSpace(ol.Connector.Config.Tools.Search.Exa.APIKey) != "" { + if ol.Connector.Config.Tools.Web != nil && + ol.Connector.Config.Tools.Web.Search != nil && + strings.TrimSpace(ol.Connector.Config.Tools.Web.Search.Exa.APIKey) != "" { return true } - if ol.Connector.Config.Tools.Fetch != nil && strings.TrimSpace(ol.Connector.Config.Tools.Fetch.Exa.APIKey) != "" { + if ol.Connector.Config.Tools.Web != nil && + ol.Connector.Config.Tools.Web.Fetch != nil && + strings.TrimSpace(ol.Connector.Config.Tools.Web.Fetch.Exa.APIKey) != "" { return true } return false diff --git a/bridges/ai/login_loaders.go b/bridges/ai/login_loaders.go index 189a6998..c18bcc49 100644 --- a/bridges/ai/login_loaders.go +++ b/bridges/ai/login_loaders.go @@ -34,13 +34,13 @@ func aiClientNeedsRebuild(existing *AIClient, key string, meta *UserLoginMetadat existingBaseURL := "" if existingMeta != nil { existingProvider = strings.TrimSpace(existingMeta.Provider) - existingBaseURL = stringutil.NormalizeBaseURL(existingMeta.BaseURL) + existingBaseURL = stringutil.NormalizeBaseURL(loginCredentialBaseURL(existingMeta)) } targetProvider := "" targetBaseURL := "" if meta != nil { targetProvider = strings.TrimSpace(meta.Provider) - targetBaseURL = stringutil.NormalizeBaseURL(meta.BaseURL) + targetBaseURL = stringutil.NormalizeBaseURL(loginCredentialBaseURL(meta)) } return existing.apiKey != key || !strings.EqualFold(existingProvider, targetProvider) || diff --git a/bridges/ai/login_loaders_test.go b/bridges/ai/login_loaders_test.go index ab02a1fc..2f3b57aa 100644 --- a/bridges/ai/login_loaders_test.go +++ b/bridges/ai/login_loaders_test.go @@ -27,19 +27,19 @@ func testUserLoginWithMeta(loginID networkid.UserLoginID, meta *UserLoginMetadat func TestAIClientNeedsRebuild(t *testing.T) { existing := &AIClient{ apiKey: "secret", - UserLogin: testUserLoginWithMeta("existing", &UserLoginMetadata{Provider: " OpenAI ", BaseURL: "https://api.example.com/v1/"}), + UserLogin: testUserLoginWithMeta("existing", &UserLoginMetadata{Provider: " OpenAI ", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1/"}}), } - if aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", BaseURL: "https://api.example.com/v1"}) { + if aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1"}}) { t.Fatal("expected no rebuild when key/provider/base URL are equivalent") } - if !aiClientNeedsRebuild(existing, "other-key", &UserLoginMetadata{Provider: "openai", BaseURL: "https://api.example.com/v1"}) { + if !aiClientNeedsRebuild(existing, "other-key", &UserLoginMetadata{Provider: "openai", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1"}}) { t.Fatal("expected rebuild when API key changes") } - if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openrouter", BaseURL: "https://api.example.com/v1"}) { + if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openrouter", Credentials: &LoginCredentials{BaseURL: "https://api.example.com/v1"}}) { t.Fatal("expected rebuild when provider changes") } - if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", BaseURL: "https://api.other.example.com/v1"}) { + if !aiClientNeedsRebuild(existing, "secret", &UserLoginMetadata{Provider: "openai", Credentials: &LoginCredentials{BaseURL: "https://api.other.example.com/v1"}}) { t.Fatal("expected rebuild when base URL changes") } if !aiClientNeedsRebuild(nil, "secret", &UserLoginMetadata{Provider: "openai"}) { diff --git a/bridges/ai/magic_proxy_test.go b/bridges/ai/magic_proxy_test.go index 9e16c78c..35550a09 100644 --- a/bridges/ai/magic_proxy_test.go +++ b/bridges/ai/magic_proxy_test.go @@ -39,8 +39,10 @@ func TestResolveServiceConfigMagicProxyUsesJoinedPaths(t *testing.T) { oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } services := oc.resolveServiceConfig(meta) @@ -48,9 +50,15 @@ func TestResolveServiceConfigMagicProxyUsesJoinedPaths(t *testing.T) { if got := services[serviceOpenRouter].BaseURL; got != "https://bai.bt.hn/team/proxy/openrouter/v1" { t.Fatalf("unexpected openrouter base URL: %q", got) } + if got := services[serviceOpenRouter].APIKey; got != "tok" { + t.Fatalf("unexpected openrouter api key: %q", got) + } if got := services[serviceOpenAI].BaseURL; got != "https://bai.bt.hn/team/proxy/openai/v1" { t.Fatalf("unexpected openai base URL: %q", got) } + if got := services[serviceOpenAI].APIKey; got != "tok" { + t.Fatalf("unexpected openai api key: %q", got) + } if got := services[serviceGemini].BaseURL; got != "https://bai.bt.hn/team/proxy/gemini/v1beta" { t.Fatalf("unexpected gemini base URL: %q", got) } @@ -63,8 +71,10 @@ func TestResolveServiceConfigMagicProxyNoDuplicateOpenRouterPath(t *testing.T) { oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + }, } services := oc.resolveServiceConfig(meta) @@ -72,6 +82,9 @@ func TestResolveServiceConfigMagicProxyNoDuplicateOpenRouterPath(t *testing.T) { if strings.Count(base, "/openrouter/v1") != 1 { t.Fatalf("openrouter path duplicated: %q", base) } + if got := services[serviceOpenRouter].APIKey; got != "tok" { + t.Fatalf("unexpected openrouter api key: %q", got) + } if got := services[serviceExa].BaseURL; got != "https://bai.bt.hn/team/proxy/exa" { t.Fatalf("unexpected exa base URL: %q", got) } @@ -87,7 +100,9 @@ func TestResolveExaProxyBaseURLMagicProxyPrefersLoginBase(t *testing.T) { } meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - BaseURL: "https://ai.bt.hn/", + Credentials: &LoginCredentials{ + BaseURL: "https://ai.bt.hn/", + }, } if got := oc.resolveExaProxyBaseURL(meta); got != "https://ai.bt.hn/exa" { t.Fatalf("unexpected exa proxy base: %q", got) diff --git a/bridges/ai/mcp_helpers.go b/bridges/ai/mcp_helpers.go index 88b5aa86..689e8144 100644 --- a/bridges/ai/mcp_helpers.go +++ b/bridges/ai/mcp_helpers.go @@ -66,24 +66,32 @@ func (oc *AIClient) verifyMCPServerConnection(ctx context.Context, server namedM } func setLoginMCPServer(meta *UserLoginMetadata, name string, cfg MCPServerConfig) { - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} + creds := ensureLoginCredentials(meta) + if creds == nil { + return + } + if creds.ServiceTokens == nil { + creds.ServiceTokens = &ServiceTokens{} } - if meta.ServiceTokens.MCPServers == nil { - meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} + if creds.ServiceTokens.MCPServers == nil { + creds.ServiceTokens.MCPServers = map[string]MCPServerConfig{} } - meta.ServiceTokens.MCPServers[name] = normalizeMCPServerConfig(cfg) + creds.ServiceTokens.MCPServers[name] = normalizeMCPServerConfig(cfg) } func clearLoginMCPServer(meta *UserLoginMetadata, name string) { - if meta == nil || meta.ServiceTokens == nil || meta.ServiceTokens.MCPServers == nil { + creds := loginCredentials(meta) + if creds == nil || creds.ServiceTokens == nil || creds.ServiceTokens.MCPServers == nil { return } - delete(meta.ServiceTokens.MCPServers, name) - if len(meta.ServiceTokens.MCPServers) == 0 { - meta.ServiceTokens.MCPServers = nil + delete(creds.ServiceTokens.MCPServers, name) + if len(creds.ServiceTokens.MCPServers) == 0 { + creds.ServiceTokens.MCPServers = nil + } + if serviceTokensEmpty(creds.ServiceTokens) { + creds.ServiceTokens = nil } - if serviceTokensEmpty(meta.ServiceTokens) { - meta.ServiceTokens = nil + if loginCredentialsEmpty(creds) { + meta.Credentials = nil } } diff --git a/bridges/ai/mcp_servers.go b/bridges/ai/mcp_servers.go index 64c188ce..8f88460e 100644 --- a/bridges/ai/mcp_servers.go +++ b/bridges/ai/mcp_servers.go @@ -144,12 +144,12 @@ func (oc *AIClient) loginMCPServers() map[string]MCPServerConfig { if oc == nil || oc.UserLogin == nil { return nil } - meta := loginMetadata(oc.UserLogin) - if meta == nil || meta.ServiceTokens == nil || len(meta.ServiceTokens.MCPServers) == 0 { + tokens := loginCredentialServiceTokens(loginMetadata(oc.UserLogin)) + if tokens == nil || len(tokens.MCPServers) == 0 { return nil } - out := make(map[string]MCPServerConfig, len(meta.ServiceTokens.MCPServers)) - for rawName, rawCfg := range meta.ServiceTokens.MCPServers { + out := make(map[string]MCPServerConfig, len(tokens.MCPServers)) + for rawName, rawCfg := range tokens.MCPServers { name := normalizeMCPServerName(rawName) cfg := normalizeMCPServerConfig(rawCfg) if name == "" { diff --git a/bridges/ai/mcp_servers_test.go b/bridges/ai/mcp_servers_test.go index 3dfe93d8..973f5365 100644 --- a/bridges/ai/mcp_servers_test.go +++ b/bridges/ai/mcp_servers_test.go @@ -10,7 +10,7 @@ import ( ) func testAIClientWithMCPServers(servers map[string]MCPServerConfig) *AIClient { - meta := &UserLoginMetadata{ServiceTokens: &ServiceTokens{MCPServers: servers}} + meta := &UserLoginMetadata{Credentials: &LoginCredentials{ServiceTokens: &ServiceTokens{MCPServers: servers}}} login := &database.UserLogin{ID: networkid.UserLoginID("login"), Metadata: meta} userLogin := &bridgev2.UserLogin{UserLogin: login, Log: zerolog.Nop()} return &AIClient{UserLogin: userLogin} diff --git a/bridges/ai/media_understanding_providers.go b/bridges/ai/media_understanding_providers.go index 513d34f1..57b7f5e0 100644 --- a/bridges/ai/media_understanding_providers.go +++ b/bridges/ai/media_understanding_providers.go @@ -400,9 +400,6 @@ type mediaAudioRequest struct { FileName string } -type mediaVideoRequest = mediaRequestBase -type mediaImageRequest = mediaRequestBase - func resolveMediaFileName(fallback string, msgType string, mediaURL string) string { base := strings.TrimSpace(fallback) if base != "" { diff --git a/bridges/ai/media_understanding_runner.go b/bridges/ai/media_understanding_runner.go index 1b468891..69253d89 100644 --- a/bridges/ai/media_understanding_runner.go +++ b/bridges/ai/media_understanding_runner.go @@ -17,7 +17,6 @@ import ( "maunium.net/go/mautrix/event" "github.com/beeper/agentremote/pkg/shared/stringutil" - bridgesdk "github.com/beeper/agentremote/sdk" ) type mediaUnderstandingResult struct { @@ -211,11 +210,9 @@ func (oc *AIClient) applyMediaUnderstandingForAttachments( } func (oc *AIClient) resolveAutoAudioEntry(cfg *MediaUnderstandingConfig) *MediaUnderstandingModelConfig { - headers := map[string]string{} - if cfg != nil && cfg.Headers != nil { - for key, value := range cfg.Headers { - headers[key] = value - } + var headers map[string]string + if cfg != nil { + headers = cfg.Headers } candidates := []struct { @@ -333,11 +330,9 @@ func (oc *AIClient) resolveKeyMediaEntry( } func (oc *AIClient) hasMediaProviderAuth(providerID string, cfg *MediaUnderstandingConfig) bool { - headers := map[string]string{} - if cfg != nil && cfg.Headers != nil { - for key, value := range cfg.Headers { - headers[key] = value - } + var headers map[string]string + if cfg != nil { + headers = cfg.Headers } if hasProviderAuthHeader(providerID, headers) { return true @@ -671,27 +666,8 @@ func (oc *AIClient) describeImageWithEntry( if actualMime == "" { actualMime = mimeType } - headers := mergeMediaHeaders(capCfg, entry) - apiKey := oc.resolveMediaProviderAPIKey("google", entry.Profile, entry.PreferredProfile) - if apiKey == "" && !hasProviderAuthHeader("google", headers) { - return nil, errors.New("missing API key for google image understanding") - } - request := mediaImageRequest{ - APIKey: apiKey, - BaseURL: resolveMediaBaseURL(capCfg, entry), - Headers: headers, - Model: strings.TrimSpace(entry.Model), - Prompt: prompt, - MimeType: actualMime, - Data: data, - Timeout: resolveMediaTimeoutSeconds(entry.TimeoutSeconds, capCfg, defaultTimeoutSecondsByCapability[MediaCapabilityImage]), - } - text, err := callGeminiForCapability(ctx, request, MediaCapabilityImage) - if err != nil { - return nil, err - } - text = truncateText(text, maxChars) - return buildMediaOutput(MediaCapabilityImage, text, "google", entry.Model, attachmentIndex), nil + timeout := resolveMediaTimeoutSeconds(entry.TimeoutSeconds, capCfg, defaultTimeoutSecondsByCapability[MediaCapabilityImage]) + return oc.callGeminiMediaCapability(ctx, MediaCapabilityImage, entry, capCfg, data, actualMime, prompt, timeout, maxChars, attachmentIndex) } rawData, actualMime, err := oc.downloadMediaBytes(ctx, mediaURL, encryptedFile, maxBytes, mimeType) @@ -705,9 +681,9 @@ func (oc *AIClient) describeImageWithEntry( actualMime = "image/jpeg" } b64Data := base64.StdEncoding.EncodeToString(rawData) - dataURL := bridgesdk.BuildDataURL(actualMime, b64Data) + dataURL := BuildDataURL(actualMime, b64Data) - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( + ctxPrompt := UserPromptContext( PromptBlock{ Type: PromptBlockText, Text: prompt, @@ -717,7 +693,7 @@ func (oc *AIClient) describeImageWithEntry( ImageURL: dataURL, MimeType: actualMime, }, - )} + ) modelIDForAPI := oc.modelIDForAPI(ResolveAlias(modelID)) var resp *GenerateResponse if entryProvider == "openrouter" { @@ -849,45 +825,31 @@ func (oc *AIClient) describeVideoWithEntry( return nil, errors.New("video payload exceeds base64 limit") } - if providerID == "openrouter" { - modelID := strings.TrimSpace(entry.Model) - if modelID == "" { - return nil, errors.New("video understanding requires model id") - } - videoB64 := base64.StdEncoding.EncodeToString(data) - - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( - PromptBlock{ - Type: PromptBlockText, - Text: prompt, - }, - PromptBlock{ - Type: PromptBlockVideo, - VideoB64: videoB64, - MimeType: actualMime, - }, - )} - modelIDForAPI := oc.modelIDForAPI(ResolveAlias(modelID)) - var resp *GenerateResponse - resp, err = oc.generateWithOpenRouter(ctx, modelIDForAPI, ctxPrompt, capCfg, entry) - if err != nil { - return nil, err - } - text := strings.TrimSpace(resp.Content) - text = truncateText(text, maxChars) - return buildMediaOutput(MediaCapabilityVideo, text, entry.Provider, modelID, attachmentIndex), nil - } if providerID != "google" { return nil, fmt.Errorf("unsupported video provider: %s", providerID) } + return oc.callGeminiMediaCapability(ctx, MediaCapabilityVideo, entry, capCfg, data, actualMime, prompt, timeout, maxChars, attachmentIndex) +} + +func (oc *AIClient) callGeminiMediaCapability( + ctx context.Context, + capability MediaUnderstandingCapability, + entry MediaUnderstandingModelConfig, + capCfg *MediaUnderstandingConfig, + data []byte, + actualMime string, + prompt string, + timeout time.Duration, + maxChars int, + attachmentIndex int, +) (*MediaUnderstandingOutput, error) { headers := mergeMediaHeaders(capCfg, entry) - apiKey := oc.resolveMediaProviderAPIKey(providerID, entry.Profile, entry.PreferredProfile) - if apiKey == "" && !hasProviderAuthHeader(providerID, headers) { - return nil, fmt.Errorf("missing API key for %s video description", providerID) + apiKey := oc.resolveMediaProviderAPIKey("google", entry.Profile, entry.PreferredProfile) + if apiKey == "" && !hasProviderAuthHeader("google", headers) { + return nil, fmt.Errorf("missing API key for google %s", capability) } - - request := mediaVideoRequest{ + request := mediaRequestBase{ APIKey: apiKey, BaseURL: resolveMediaBaseURL(capCfg, entry), Headers: headers, @@ -897,13 +859,13 @@ func (oc *AIClient) describeVideoWithEntry( Data: data, Timeout: timeout, } - text, err := callGeminiForCapability(ctx, request, MediaCapabilityVideo) + text, err := callGeminiForCapability(ctx, request, capability) if err != nil { return nil, err } text = strings.TrimSpace(text) text = truncateText(text, maxChars) - return buildMediaOutput(MediaCapabilityVideo, text, providerID, entry.Model, attachmentIndex), nil + return buildMediaOutput(capability, text, "google", entry.Model, attachmentIndex), nil } func (oc *AIClient) generateWithOpenRouter( @@ -926,9 +888,6 @@ func (oc *AIClient) generateWithOpenRouter( Context: promptContext, MaxCompletionTokens: defaultImageUnderstandingLimit, } - if bridgesdk.PromptContextHasBlockType(promptContext.PromptContext, PromptBlockAudio, PromptBlockVideo) { - return provider.generateChatCompletions(ctx, params) - } return provider.Generate(ctx, params) } @@ -953,10 +912,7 @@ func (oc *AIClient) resolveOpenRouterMediaConfig( if baseURL == "" { baseURL = resolveOpenRouterMediaBaseURL(oc) } - pdfEngine = oc.connector.Config.Providers.OpenRouter.DefaultPDFEngine - if pdfEngine == "" { - pdfEngine = "mistral-ocr" - } + pdfEngine = oc.defaultPDFEngine() if oc.UserLogin != nil && oc.UserLogin.User != nil && oc.UserLogin.User.MXID != "" { userID = oc.UserLogin.User.MXID.String() } diff --git a/bridges/ai/media_understanding_runner_openai_test.go b/bridges/ai/media_understanding_runner_openai_test.go index 709a71ea..14d595b5 100644 --- a/bridges/ai/media_understanding_runner_openai_test.go +++ b/bridges/ai/media_understanding_runner_openai_test.go @@ -26,8 +26,10 @@ func TestResolveMediaProviderAPIKeyOpenAIMagicProxyUsesLoginToken(t *testing.T) meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } client := newMediaTestClient(meta, &OpenAIConnector{}) @@ -39,8 +41,10 @@ func TestResolveMediaProviderAPIKeyOpenAIMagicProxyUsesLoginToken(t *testing.T) func TestResolveOpenAIMediaBaseURLMagicProxyUsesOpenAIServicePath(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "tok", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "tok", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } client := newMediaTestClient(meta, &OpenAIConnector{}) @@ -54,11 +58,7 @@ func TestResolveOpenRouterMediaConfigUsesEntryOverrides(t *testing.T) { client := newMediaTestClient(&UserLoginMetadata{Provider: ProviderOpenAI}, &OpenAIConnector{ Config: Config{ - Providers: ProvidersConfig{ - OpenRouter: ProviderConfig{ - DefaultPDFEngine: "native", - }, - }, + Agents: &AgentsConfig{Defaults: &AgentDefaultsConfig{PDFEngine: "mistral-ocr"}}, }, }) @@ -99,8 +99,8 @@ func TestResolveOpenRouterMediaConfigUsesEntryOverrides(t *testing.T) { if headers["X-Title"] != openRouterAppTitle { t.Fatalf("expected default OpenRouter title header, got %#v", headers) } - if pdfEngine != "native" { - t.Fatalf("expected configured PDF engine, got %q", pdfEngine) + if pdfEngine != "mistral-ocr" { + t.Fatalf("expected default PDF engine, got %q", pdfEngine) } } diff --git a/bridges/ai/messages.go b/bridges/ai/messages.go index 0504b278..66a0cb54 100644 --- a/bridges/ai/messages.go +++ b/bridges/ai/messages.go @@ -1,32 +1,83 @@ package ai -import bridgesdk "github.com/beeper/agentremote/sdk" +import "strings" -type PromptRole = bridgesdk.PromptRole +type PromptRole string const ( - PromptRoleUser PromptRole = bridgesdk.PromptRoleUser - PromptRoleAssistant PromptRole = bridgesdk.PromptRoleAssistant - PromptRoleToolResult PromptRole = bridgesdk.PromptRoleToolResult + PromptRoleUser PromptRole = "user" + PromptRoleAssistant PromptRole = "assistant" + PromptRoleToolResult PromptRole = "tool_result" ) -type PromptBlockType = bridgesdk.PromptBlockType +type PromptBlockType string const ( - PromptBlockText PromptBlockType = bridgesdk.PromptBlockText - PromptBlockImage PromptBlockType = bridgesdk.PromptBlockImage - PromptBlockFile PromptBlockType = bridgesdk.PromptBlockFile - PromptBlockThinking PromptBlockType = bridgesdk.PromptBlockThinking - PromptBlockToolCall PromptBlockType = bridgesdk.PromptBlockToolCall - PromptBlockAudio PromptBlockType = bridgesdk.PromptBlockAudio - PromptBlockVideo PromptBlockType = bridgesdk.PromptBlockVideo + PromptBlockText PromptBlockType = "text" + PromptBlockImage PromptBlockType = "image" + PromptBlockThinking PromptBlockType = "thinking" + PromptBlockToolCall PromptBlockType = "tool_call" ) -type PromptBlock = bridgesdk.PromptBlock -type PromptMessage = bridgesdk.PromptMessage +type PromptBlock struct { + Type PromptBlockType -// PromptContext extends the shared provider-facing prompt model with bridge-local tool definitions. + Text string + + ImageURL string + ImageB64 string + MimeType string + + ToolCallID string + ToolName string + ToolCallArguments string +} + +type PromptMessage struct { + Role PromptRole + Blocks []PromptBlock + ToolCallID string + ToolName string + IsError bool +} + +func (m PromptMessage) text(includeThinking bool) string { + var sb strings.Builder + for _, block := range m.Blocks { + switch block.Type { + case PromptBlockText: + if block.Text != "" { + if sb.Len() > 0 { + sb.WriteByte('\n') + } + sb.WriteString(block.Text) + } + case PromptBlockThinking: + if !includeThinking || block.Text == "" { + continue + } + if block.Text != "" { + if sb.Len() > 0 { + sb.WriteByte('\n') + } + sb.WriteString(block.Text) + } + } + } + return sb.String() +} + +func (m PromptMessage) Text() string { + return m.text(true) +} + +func (m PromptMessage) VisibleText() string { + return m.text(false) +} + +// PromptContext is the bridge-local prompt envelope used throughout bridges/ai. type PromptContext struct { - bridgesdk.PromptContext - Tools []ToolDefinition + SystemPrompt string + Messages []PromptMessage + Tools []ToolDefinition } diff --git a/bridges/ai/messages_responses_input_test.go b/bridges/ai/messages_responses_input_test.go index d414b784..081bf01d 100644 --- a/bridges/ai/messages_responses_input_test.go +++ b/bridges/ai/messages_responses_input_test.go @@ -4,15 +4,12 @@ import ( "testing" "github.com/openai/openai-go/v3/responses" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { - input := bridgesdk.PromptContextToResponsesInput(bridgesdk.UserPromptContext( + input := promptContextToResponsesInput(UserPromptContext( PromptBlock{Type: PromptBlockText, Text: "hello"}, PromptBlock{Type: PromptBlockImage, ImageB64: "aGVsbG8=", MimeType: "image/png"}, - PromptBlock{Type: PromptBlockFile, FileB64: "cGRm", Filename: "document.pdf"}, )) if len(input) != 1 { t.Fatalf("expected 1 input item, got %d", len(input)) @@ -33,7 +30,6 @@ func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { foundText := false foundImage := false - foundFile := false for _, part := range parts { if part.OfInputText != nil { foundText = true @@ -47,18 +43,31 @@ func TestPromptContextToResponsesInput_MultimodalUser(t *testing.T) { t.Fatalf("expected image part data URL to preserve content, got %#v", part.OfInputImage.ImageURL.Value) } } - if part.OfInputFile != nil { - foundFile = true - if part.OfInputFile.Filename.Value != "document.pdf" { - t.Fatalf("expected file part filename document.pdf, got %#v", part.OfInputFile.Filename.Value) - } - if part.OfInputFile.FileData.Value != "cGRm" { - t.Fatalf("expected file part data to preserve content, got %#v", part.OfInputFile.FileData.Value) - } - } } - if !foundText || !foundImage || !foundFile { - t.Fatalf("expected text, image, and file parts (got text=%v image=%v file=%v)", foundText, foundImage, foundFile) + if !foundText || !foundImage { + t.Fatalf("expected text and image parts (got text=%v image=%v)", foundText, foundImage) + } +} + +func TestPromptContextToResponsesInput_AssistantOmitsThinkingBlocks(t *testing.T) { + input := promptContextToResponsesInput(PromptContext{ + Messages: []PromptMessage{{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{ + {Type: PromptBlockThinking, Text: "internal analysis"}, + {Type: PromptBlockText, Text: "visible reply"}, + }, + }}, + }) + if len(input) != 1 { + t.Fatalf("expected 1 input item, got %d", len(input)) + } + item := input[0].OfMessage + if item == nil { + t.Fatalf("expected assistant message input") + } + if got := item.Content.OfString.Value; got != "visible reply" { + t.Fatalf("expected only visible reply text, got %q", got) } } diff --git a/bridges/ai/metadata.go b/bridges/ai/metadata.go index e6492b93..1afe73ed 100644 --- a/bridges/ai/metadata.go +++ b/bridges/ai/metadata.go @@ -4,6 +4,7 @@ import ( "encoding/json" "maps" "slices" + "strings" "go.mau.fi/util/jsontime" "go.mau.fi/util/random" @@ -53,6 +54,13 @@ type UserProfile struct { CustomInstructions string `json:"custom_instructions,omitempty"` } +// LoginCredentials stores the per-login credentials and service-specific tokens. +type LoginCredentials struct { + APIKey string `json:"api_key,omitempty"` + BaseURL string `json:"base_url,omitempty"` + ServiceTokens *ServiceTokens `json:"service_tokens,omitempty"` +} + // ServiceTokens stores optional per-login credentials for external services. type ServiceTokens struct { OpenAI string `json:"openai,omitempty"` @@ -109,26 +117,22 @@ type BuiltinAlwaysAllowRule struct { // UserLoginMetadata is stored on each login row to keep per-user settings. type UserLoginMetadata struct { - Provider string `json:"provider,omitempty"` // Selected provider (beeper, openai, openrouter) - APIKey string `json:"api_key,omitempty"` - BaseURL string `json:"base_url,omitempty"` // Per-user API endpoint - TitleGenerationModel string `json:"title_generation_model,omitempty"` // Model to use for generating chat titles - Agents *bool `json:"agents,omitempty"` // Nil/true enables agents, false limits login to model rooms - NextChatIndex int `json:"next_chat_index,omitempty"` - DefaultChatPortalID string `json:"default_chat_portal_id,omitempty"` - ModelCache *ModelCache `json:"model_cache,omitempty"` - ChatsSynced bool `json:"chats_synced,omitempty"` // True after initial bootstrap completed successfully - Gravatar *GravatarState `json:"gravatar,omitempty"` - Timezone string `json:"timezone,omitempty"` - Profile *UserProfile `json:"profile,omitempty"` + Provider string `json:"provider,omitempty"` // Selected provider (beeper, openai, openrouter) + Credentials *LoginCredentials `json:"credentials,omitempty"` + TitleGenerationModel string `json:"title_generation_model,omitempty"` // Model to use for generating chat titles + Agents *bool `json:"agents,omitempty"` // Nil/true enables agents, false limits login to model rooms + NextChatIndex int `json:"next_chat_index,omitempty"` + DefaultChatPortalID string `json:"default_chat_portal_id,omitempty"` + ModelCache *ModelCache `json:"model_cache,omitempty"` + ChatsSynced bool `json:"chats_synced,omitempty"` // True after initial bootstrap completed successfully + Gravatar *GravatarState `json:"gravatar,omitempty"` + Timezone string `json:"timezone,omitempty"` + Profile *UserProfile `json:"profile,omitempty"` // FileAnnotationCache stores parsed PDF content from OpenRouter's file-parser plugin // Key is the file hash (SHA256), pruned after 7 days FileAnnotationCache map[string]FileAnnotation `json:"file_annotation_cache,omitempty"` - // Optional per-login tokens for external services - ServiceTokens *ServiceTokens `json:"service_tokens,omitempty"` - // Tool approval rules (e.g. "always allow" decisions for MCP approvals or dangerous builtin tools). ToolApprovals *ToolApprovalsConfig `json:"tool_approvals,omitempty"` @@ -146,6 +150,101 @@ type UserLoginMetadata struct { LastErrorAt int64 `json:"last_error_at,omitempty"` // Unix timestamp } +func loginCredentials(meta *UserLoginMetadata) *LoginCredentials { + if meta == nil { + return nil + } + return meta.Credentials +} + +func ensureLoginCredentials(meta *UserLoginMetadata) *LoginCredentials { + if meta == nil { + return nil + } + if meta.Credentials == nil { + meta.Credentials = &LoginCredentials{} + } + return meta.Credentials +} + +func loginCredentialAPIKey(meta *UserLoginMetadata) string { + if creds := loginCredentials(meta); creds != nil { + return strings.TrimSpace(creds.APIKey) + } + return "" +} + +func loginCredentialBaseURL(meta *UserLoginMetadata) string { + if creds := loginCredentials(meta); creds != nil { + return strings.TrimSpace(creds.BaseURL) + } + return "" +} + +func loginCredentialServiceTokens(meta *UserLoginMetadata) *ServiceTokens { + if creds := loginCredentials(meta); creds != nil { + return creds.ServiceTokens + } + return nil +} + +func loginCredentialsEmpty(creds *LoginCredentials) bool { + if creds == nil { + return true + } + return strings.TrimSpace(creds.APIKey) == "" && + strings.TrimSpace(creds.BaseURL) == "" && + serviceTokensEmpty(creds.ServiceTokens) +} + +func cloneServiceTokens(src *ServiceTokens) *ServiceTokens { + if src == nil { + return nil + } + clone := *src + if src.DesktopAPIInstances != nil { + clone.DesktopAPIInstances = maps.Clone(src.DesktopAPIInstances) + } + if src.MCPServers != nil { + clone.MCPServers = maps.Clone(src.MCPServers) + } + return &clone +} + +func serviceTokensEmpty(tokens *ServiceTokens) bool { + if tokens == nil { + return true + } + if len(tokens.DesktopAPIInstances) > 0 { + for _, instance := range tokens.DesktopAPIInstances { + if strings.TrimSpace(instance.Token) != "" || strings.TrimSpace(instance.BaseURL) != "" { + return false + } + } + } + if len(tokens.MCPServers) > 0 { + for _, server := range tokens.MCPServers { + if strings.TrimSpace(server.Transport) != "" || + strings.TrimSpace(server.Endpoint) != "" || + strings.TrimSpace(server.Command) != "" || + len(server.Args) > 0 || + strings.TrimSpace(server.Token) != "" || + strings.TrimSpace(server.AuthURL) != "" || + strings.TrimSpace(server.AuthType) != "" || + strings.TrimSpace(server.Kind) != "" || + server.Connected { + return false + } + } + } + return strings.TrimSpace(tokens.OpenAI) == "" && + strings.TrimSpace(tokens.OpenRouter) == "" && + strings.TrimSpace(tokens.Exa) == "" && + strings.TrimSpace(tokens.Brave) == "" && + strings.TrimSpace(tokens.Perplexity) == "" && + strings.TrimSpace(tokens.DesktopAPI) == "" +} + // HeartbeatState tracks last heartbeat delivery for dedupe. type HeartbeatState struct { LastHeartbeatText string `json:"last_heartbeat_text,omitempty"` @@ -212,6 +311,32 @@ func (m *PortalMetadata) SetModuleMeta(key string, value any) { m.ModuleMeta[key] = value } +func (m *PortalMetadata) ModuleMetaValue(key string) any { + if m == nil || m.ModuleMeta == nil { + return nil + } + return m.ModuleMeta[key] +} + +func (m *PortalMetadata) SetModuleMetaValue(key string, value any) { + m.SetModuleMeta(key, value) +} + +func (m *PortalMetadata) AgentID() string { + return resolveAgentID(m) +} + +func (m *PortalMetadata) CompactionCounter() int { + if m == nil { + return 0 + } + return m.CompactionCount +} + +func (m *PortalMetadata) InternalRoom() bool { + return isModuleInternalRoom(m) +} + func cloneUserLoginMetadata(src *UserLoginMetadata) (*UserLoginMetadata, error) { if src == nil { return &UserLoginMetadata{}, nil @@ -227,59 +352,6 @@ func cloneUserLoginMetadata(src *UserLoginMetadata) (*UserLoginMetadata, error) return &clone, nil } -func mergeServiceTokens(existing, incoming *ServiceTokens) *ServiceTokens { - if incoming == nil { - return existing - } - if existing == nil { - clone := *incoming - if incoming.DesktopAPIInstances != nil { - clone.DesktopAPIInstances = maps.Clone(incoming.DesktopAPIInstances) - } - if incoming.MCPServers != nil { - clone.MCPServers = maps.Clone(incoming.MCPServers) - } - return &clone - } - - merged := *existing - if incoming.OpenAI != "" { - merged.OpenAI = incoming.OpenAI - } - if incoming.OpenRouter != "" { - merged.OpenRouter = incoming.OpenRouter - } - if incoming.Exa != "" { - merged.Exa = incoming.Exa - } - if incoming.Brave != "" { - merged.Brave = incoming.Brave - } - if incoming.Perplexity != "" { - merged.Perplexity = incoming.Perplexity - } - if incoming.DesktopAPI != "" { - merged.DesktopAPI = incoming.DesktopAPI - } - if len(incoming.DesktopAPIInstances) > 0 { - if merged.DesktopAPIInstances == nil { - merged.DesktopAPIInstances = make(map[string]DesktopAPIInstance, len(incoming.DesktopAPIInstances)) - } - for key, value := range incoming.DesktopAPIInstances { - merged.DesktopAPIInstances[key] = value - } - } - if len(incoming.MCPServers) > 0 { - if merged.MCPServers == nil { - merged.MCPServers = make(map[string]MCPServerConfig, len(incoming.MCPServers)) - } - for key, value := range incoming.MCPServers { - merged.MCPServers[key] = value - } - } - return &merged -} - func agentsEnabled(meta *UserLoginMetadata) bool { if meta == nil || meta.Agents == nil { return false diff --git a/bridges/ai/model_catalog_test.go b/bridges/ai/model_catalog_test.go index 810817df..fc3d56df 100644 --- a/bridges/ai/model_catalog_test.go +++ b/bridges/ai/model_catalog_test.go @@ -10,7 +10,9 @@ func TestImplicitModelCatalogEntries_MagicProxySeedsCatalog(t *testing.T) { // Magic Proxy logins store the API key on the login metadata. meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "mp-token", + Credentials: &LoginCredentials{ + APIKey: "mp-token", + }, } entries := oc.implicitModelCatalogEntries(meta) @@ -18,3 +20,40 @@ func TestImplicitModelCatalogEntries_MagicProxySeedsCatalog(t *testing.T) { t.Fatalf("expected non-empty model catalog entries for magic_proxy, got 0") } } + +func TestImplicitModelCatalogEntries_OpenAILoginUsesManifestMetadata(t *testing.T) { + oc := &AIClient{ + connector: &OpenAIConnector{}, + } + meta := &UserLoginMetadata{ + Provider: ProviderOpenAI, + Credentials: &LoginCredentials{ + APIKey: "openai-token", + }, + } + + entries := oc.implicitModelCatalogEntries(meta) + if len(entries) == 0 { + t.Fatalf("expected non-empty model catalog entries for openai, got 0") + } + + entry := findModelCatalogEntry(entries, ProviderOpenAI, "gpt-5.4-mini") + if entry == nil { + t.Fatal("expected gpt-5.4-mini entry in openai catalog") + } + + manifestInfo, ok := ModelManifest.Models["openai/gpt-5.4-mini"] + if !ok { + t.Fatal("expected gpt-5.4-mini in manifest") + } + + if entry.ContextWindow != manifestInfo.ContextWindow { + t.Fatalf("context window = %d, want %d", entry.ContextWindow, manifestInfo.ContextWindow) + } + if entry.MaxOutputTokens != manifestInfo.MaxOutputTokens { + t.Fatalf("max output tokens = %d, want %d", entry.MaxOutputTokens, manifestInfo.MaxOutputTokens) + } + if !catalogInputIncludes(entry, "image") || !catalogInputIncludes(entry, "pdf") { + t.Fatalf("expected openai catalog entry to retain manifest modalities, got %#v", entry.Input) + } +} diff --git a/bridges/ai/pending_event.go b/bridges/ai/pending_event.go new file mode 100644 index 00000000..e2143ba9 --- /dev/null +++ b/bridges/ai/pending_event.go @@ -0,0 +1,55 @@ +package ai + +import "maunium.net/go/mautrix/event" + +// snapshotPendingEvent copies only the event fields that queued/goroutine-based +// reply targeting and status propagation depend on. +func snapshotPendingEvent(evt *event.Event) *event.Event { + if evt == nil { + return nil + } + cloned := &event.Event{ + ID: evt.ID, + Type: evt.Type, + Sender: evt.Sender, + } + if len(evt.Content.Raw) > 0 { + cloned.Content.Raw = clonePendingRawMap(evt.Content.Raw) + } + return cloned +} + +func clonePendingRawMap(src map[string]any) map[string]any { + if len(src) == 0 { + return nil + } + cloned := make(map[string]any, len(src)) + for k, v := range src { + cloned[k] = clonePendingRawValue(v) + } + return cloned +} + +func clonePendingRawSlice(src []any) []any { + if len(src) == 0 { + return nil + } + cloned := make([]any, len(src)) + for i, v := range src { + cloned[i] = clonePendingRawValue(v) + } + return cloned +} + +func clonePendingRawValue(v any) any { + switch typed := v.(type) { + case map[string]any: + return clonePendingRawMap(typed) + case []any: + return clonePendingRawSlice(typed) + case []byte: + return append([]byte(nil), typed...) + default: + return v + } +} diff --git a/bridges/ai/pending_event_test.go b/bridges/ai/pending_event_test.go new file mode 100644 index 00000000..2ed950c2 --- /dev/null +++ b/bridges/ai/pending_event_test.go @@ -0,0 +1,49 @@ +package ai + +import ( + "context" + "testing" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestSnapshotPendingEventPreservesReplyTargetAfterSourceMutation(t *testing.T) { + original := &event.Event{ + ID: id.EventID("$evt"), + Sender: id.UserID("@alice:example.com"), + Content: event.Content{ + Raw: map[string]any{ + "m.relates_to": map[string]any{ + "m.in_reply_to": map[string]any{ + "event_id": "$parent", + }, + }, + }, + }, + } + + snapshot := snapshotPendingEvent(original) + original.ID = id.EventID("$mutated") + original.Sender = id.UserID("@bob:example.com") + original.Content.Raw["m.relates_to"].(map[string]any)["m.in_reply_to"].(map[string]any)["event_id"] = "$other-parent" + + oc := &AIClient{} + meta := modelModeTestMeta("openai/gpt-5.2") + prep, cleanup := oc.prepareStreamingRun(context.Background(), zerolog.Nop(), snapshot, nil, meta) + defer cleanup() + + if prep.State == nil || prep.State.turn == nil { + t.Fatalf("expected streaming turn to be prepared") + } + if got := prep.State.turn.Source().EventID; got != "$evt" { + t.Fatalf("expected snapped source event id %q, got %q", "$evt", got) + } + if got := prep.State.turn.Source().SenderID; got != "@alice:example.com" { + t.Fatalf("expected snapped sender id %q, got %q", "@alice:example.com", got) + } + if got := prep.State.replyTarget.ReplyTo; got != id.EventID("$parent") { + t.Fatalf("expected snapped reply target %q, got %q", "$parent", got) + } +} diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index bea251ac..aee69c52 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -4,9 +4,9 @@ import ( "context" "slices" "strings" + "sync" "time" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" @@ -25,6 +25,7 @@ type pendingQueueItem struct { } type pendingQueue struct { + mu sync.Mutex items []pendingQueueItem draining bool lastEnqueuedAt int64 @@ -58,6 +59,7 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe } oc.pendingQueues[roomID] = queue } else { + queue.mu.Lock() queue.mode = settings.Mode if settings.DebounceMs >= 0 { queue.debounceMs = settings.DebounceMs @@ -68,6 +70,7 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe if settings.DropPolicy != "" { queue.dropPolicy = settings.DropPolicy } + queue.mu.Unlock() } return queue } @@ -87,6 +90,8 @@ func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, if queue == nil { return false } + queue.mu.Lock() + defer queue.mu.Unlock() for _, existing := range queue.items { if pendingQueueItemsConflict(item, existing) { @@ -152,6 +157,8 @@ func (oc *AIClient) popQueueItems(roomID id.RoomID, count int) []pendingQueueIte if queue == nil || len(queue.items) == 0 || count <= 0 { return nil } + queue.mu.Lock() + defer queue.mu.Unlock() if count > len(queue.items) { count = len(queue.items) } @@ -171,12 +178,33 @@ func (oc *AIClient) getQueueSnapshot(roomID id.RoomID) *pendingQueue { if queue == nil { return nil } + queue.mu.Lock() + defer queue.mu.Unlock() clone := *queue clone.items = slices.Clone(queue.items) clone.summaryLines = slices.Clone(queue.summaryLines) + if queue.lastItem != nil { + lastItem := *queue.lastItem + clone.lastItem = &lastItem + } return &clone } +func (oc *AIClient) roomHasPendingQueueWork(roomID id.RoomID) bool { + if oc == nil || roomID == "" { + return false + } + oc.pendingQueuesMu.Lock() + defer oc.pendingQueuesMu.Unlock() + queue := oc.pendingQueues[roomID] + if queue == nil { + return false + } + queue.mu.Lock() + defer queue.mu.Unlock() + return queue.draining || len(queue.items) > 0 || queue.droppedCount > 0 +} + func (oc *AIClient) consumeQueueSummary(roomID id.RoomID, noun string) string { oc.pendingQueuesMu.Lock() defer oc.pendingQueuesMu.Unlock() @@ -184,6 +212,8 @@ func (oc *AIClient) consumeQueueSummary(roomID id.RoomID, noun string) string { if queue == nil || queue.droppedCount == 0 { return "" } + queue.mu.Lock() + defer queue.mu.Unlock() summary := buildQueueSummaryPrompt(queue, noun) queue.droppedCount = 0 queue.summaryLines = nil @@ -324,47 +354,21 @@ func (oc *AIClient) getSteeringMessages(roomID id.RoomID) []string { return messages } -func buildSteeringUserMessages(prompts []string) []openai.ChatCompletionMessageParamUnion { +func buildSteeringPromptMessages(prompts []string) []PromptMessage { if len(prompts) == 0 { return nil } - messages := make([]openai.ChatCompletionMessageParamUnion, 0, len(prompts)) + messages := make([]PromptMessage, 0, len(prompts)) for _, prompt := range prompts { prompt = strings.TrimSpace(prompt) if prompt == "" { continue } - messages = append(messages, openai.UserMessage(prompt)) + messages = append(messages, newUserTextPromptMessage(prompt)) } return messages } -func (oc *AIClient) getFollowUpMessages(roomID id.RoomID) []openai.ChatCompletionMessageParamUnion { - if oc == nil || roomID == "" { - return nil - } - snapshot := oc.getQueueSnapshot(roomID) - if snapshot == nil { - return nil - } - behavior := airuntime.ResolveQueueBehavior(snapshot.mode) - if !behavior.Followup { - return nil - } - candidate, _ := oc.takePendingQueueDispatchCandidate(roomID, true) - if candidate == nil || len(candidate.items) == 0 { - return nil - } - for _, item := range candidate.items { - oc.registerRoomRunPendingItem(roomID, item) - } - _, prompt, ok := preparePendingQueueDispatchCandidate(candidate) - if !ok { - return nil - } - return buildSteeringUserMessages([]string{prompt}) -} - func (oc *AIClient) markQueueDraining(roomID id.RoomID) bool { oc.pendingQueuesMu.Lock() defer oc.pendingQueuesMu.Unlock() @@ -372,6 +376,11 @@ func (oc *AIClient) markQueueDraining(roomID id.RoomID) bool { if queue == nil || queue.draining { return false } + queue.mu.Lock() + defer queue.mu.Unlock() + if queue.draining { + return false + } queue.draining = true return true } @@ -383,6 +392,8 @@ func (oc *AIClient) clearQueueDraining(roomID id.RoomID) { if queue == nil { return } + queue.mu.Lock() + defer queue.mu.Unlock() queue.draining = false if len(queue.items) == 0 && queue.droppedCount == 0 { delete(oc.pendingQueues, roomID) diff --git a/bridges/ai/prompt_builder.go b/bridges/ai/prompt_builder.go new file mode 100644 index 00000000..e55ccad8 --- /dev/null +++ b/bridges/ai/prompt_builder.go @@ -0,0 +1,290 @@ +package ai + +import ( + "context" + "fmt" + "slices" + "strings" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type historyReplayMode string + +const ( + historyReplayNormal historyReplayMode = "normal" + historyReplayRegen historyReplayMode = "regenerate" + historyReplayRewrite historyReplayMode = "rewrite" +) + +type historyReplayOptions struct { + mode historyReplayMode + targetMessageID networkid.MessageID + excludeMessageID networkid.MessageID +} + +type currentTurnTextOptions struct { + rawEventContent map[string]any + includeLinkScope bool + prepend []string + append []string +} + +type turnAttachmentOptions struct { + mediaURL string + mimeType string + encryptedFile *event.EncryptedFileInfo + mediaType pendingMessageType +} + +type currentTurnPromptOptions struct { + currentTurnTextOptions + leadingBlocks []PromptBlock + attachment *turnAttachmentOptions +} + +func joinPromptFragments(parts ...string) string { + var filtered []string + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + filtered = append(filtered, part) + } + } + return strings.TrimSpace(strings.Join(filtered, "\n\n")) +} + +func (oc *AIClient) fetchHistoryRowsWithExtra( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + extra int, +) (*historyLoadResult, error) { + historyLimit := oc.historyLimit(ctx, portal, meta) + if historyLimit <= 0 { + return nil, nil + } + if extra > 0 { + historyLimit += extra + } + resetAt := int64(0) + if meta != nil { + resetAt = meta.SessionResetAt + } + history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(ctx, portal.PortalKey, historyLimit) + if err != nil { + return nil, err + } + return &historyLoadResult{ + rows: history, + hasVision: oc.getModelCapabilitiesForMeta(ctx, meta).SupportsVision, + resetAt: resetAt, + }, nil +} + +func (oc *AIClient) replayHistoryMessages( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + opts historyReplayOptions, +) ([]PromptMessage, error) { + extra := 0 + if opts.mode == historyReplayRegen { + extra = 2 + } + hr, err := oc.fetchHistoryRowsWithExtra(ctx, portal, meta, extra) + if err != nil { + return nil, err + } + if hr == nil { + return nil, nil + } + + type replayCandidate struct { + row *database.Message + meta *MessageMetadata + } + + candidates := make([]replayCandidate, 0, len(hr.rows)) + for _, row := range hr.rows { + if opts.excludeMessageID != "" && row.ID == opts.excludeMessageID { + continue + } + msgMeta := messageMeta(row) + if opts.mode == historyReplayRewrite && row.ID == opts.targetMessageID { + candidates = append(candidates, replayCandidate{row: row, meta: msgMeta}) + continue + } + if !shouldIncludeInHistory(msgMeta) { + continue + } + if hr.resetAt > 0 && row.Timestamp.UnixMilli() < hr.resetAt { + continue + } + candidates = append(candidates, replayCandidate{row: row, meta: msgMeta}) + } + + skipUserID := networkid.MessageID("") + skipAssistantID := networkid.MessageID("") + if opts.mode == historyReplayRegen { + for _, candidate := range candidates { + if skipUserID == "" && candidate.meta != nil && candidate.meta.Role == string(PromptRoleUser) { + skipUserID = candidate.row.ID + continue + } + if skipAssistantID == "" && candidate.meta != nil && candidate.meta.Role == string(PromptRoleAssistant) { + skipAssistantID = candidate.row.ID + } + if skipUserID != "" && skipAssistantID != "" { + break + } + } + } + + var messages []PromptMessage + for i := len(candidates) - 1; i >= 0; i-- { + candidate := candidates[i] + if opts.mode == historyReplayRewrite && candidate.row.ID == opts.targetMessageID { + break + } + if candidate.row.ID == skipUserID || candidate.row.ID == skipAssistantID { + continue + } + injectImages := hr.hasVision && i < maxHistoryImageMessages + messages = append(messages, oc.historyMessageBundle(ctx, candidate.meta, injectImages)...) + } + return messages, nil +} + +func (oc *AIClient) buildCurrentTurnText( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + userText string, + eventID id.EventID, + opts currentTurnTextOptions, +) (PromptContext, string, error) { + result, err := oc.prepareInboundPromptContext(ctx, portal, meta, userText, eventID) + if err != nil { + return PromptContext{}, "", err + } + + prepend := slices.Clone(opts.prepend) + if portal != nil && portal.MXID != "" { + reactionFeedback := DrainReactionFeedback(portal.MXID) + if len(reactionFeedback) > 0 { + if feedbackText := FormatReactionFeedback(reactionFeedback); feedbackText != "" { + prepend = append(prepend, feedbackText) + } + } + } + if result.UntrustedPrefix != "" { + prepend = append(prepend, result.UntrustedPrefix) + } + + appendParts := slices.Clone(opts.append) + if opts.includeLinkScope { + if linkContext := oc.buildLinkContext(ctx, userText, opts.rawEventContent); linkContext != "" { + appendParts = append(appendParts, linkContext) + } + } + + body := joinPromptFragments(append(append(prepend, result.ResolvedBody), appendParts...)...) + return result.PromptContext, body, nil +} + +func (oc *AIClient) buildPromptContextForTurn( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + userText string, + eventID id.EventID, + opts currentTurnPromptOptions, +) (PromptContext, error) { + appendFragments := slices.Clone(opts.append) + leadingBlocks := slices.Clone(opts.leadingBlocks) + + if opts.attachment != nil { + attachmentBlocks, attachmentAppend, err := oc.normalizeTurnAttachment(ctx, *opts.attachment) + if err != nil { + return PromptContext{}, err + } + leadingBlocks = append(leadingBlocks, attachmentBlocks...) + appendFragments = append(appendFragments, attachmentAppend...) + } + + textOpts := opts.currentTurnTextOptions + textOpts.append = appendFragments + base, text, err := oc.buildCurrentTurnText(ctx, portal, meta, userText, eventID, textOpts) + if err != nil { + return PromptContext{}, err + } + + blocks := make([]PromptBlock, 0, len(leadingBlocks)+1) + blocks = append(blocks, leadingBlocks...) + if strings.TrimSpace(text) != "" { + blocks = append(blocks, PromptBlock{Type: PromptBlockText, Text: text}) + } + base.Messages = append(base.Messages, PromptMessage{ + Role: PromptRoleUser, + Blocks: blocks, + }) + return base, nil +} + +func (oc *AIClient) normalizeTurnAttachment(ctx context.Context, opts turnAttachmentOptions) ([]PromptBlock, []string, error) { + switch opts.mediaType { + case pendingTypeImage: + b64Data, actualMimeType, err := oc.downloadMediaBase64(ctx, opts.mediaURL, opts.encryptedFile, 20, opts.mimeType) + if err != nil { + return nil, nil, fmt.Errorf("failed to download image: %w", err) + } + return []PromptBlock{{ + Type: PromptBlockImage, + ImageB64: b64Data, + MimeType: actualMimeType, + }}, nil, nil + case pendingTypePDF: + content, truncated, err := oc.downloadPDFFile(ctx, opts.mediaURL, opts.encryptedFile, opts.mimeType) + if err != nil { + return nil, nil, fmt.Errorf("failed to download PDF: %w", err) + } + filename := resolveMediaFileName("document.pdf", "pdf", opts.mediaURL) + return nil, []string{buildTextFileMessage("", false, filename, "application/pdf", content, truncated)}, nil + case pendingTypeAudio: + return nil, nil, fmt.Errorf("audio attachments must be preprocessed into text before prompt assembly") + case pendingTypeVideo: + return nil, nil, fmt.Errorf("video attachments must be preprocessed into text before prompt assembly") + default: + return nil, nil, fmt.Errorf("unsupported media type: %s", opts.mediaType) + } +} + +func (oc *AIClient) buildCurrentTurnWithLinks( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + userText string, + rawEventContent map[string]any, + eventID id.EventID, +) (PromptContext, error) { + return oc.buildPromptContextForTurn(ctx, portal, meta, userText, eventID, currentTurnPromptOptions{ + currentTurnTextOptions: currentTurnTextOptions{ + rawEventContent: rawEventContent, + includeLinkScope: true, + }, + }) +} + +func (oc *AIClient) buildHeartbeatTurnContext( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + prompt string, +) (PromptContext, error) { + return oc.buildPromptContextForTurn(ctx, portal, meta, prompt, "", currentTurnPromptOptions{}) +} diff --git a/bridges/ai/prompt_context_local.go b/bridges/ai/prompt_context_local.go new file mode 100644 index 00000000..8514aa2c --- /dev/null +++ b/bridges/ai/prompt_context_local.go @@ -0,0 +1,375 @@ +package ai + +import ( + "fmt" + "slices" + "strings" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" +) + +func UserPromptContext(blocks ...PromptBlock) PromptContext { + return PromptContext{ + Messages: []PromptMessage{{ + Role: PromptRoleUser, + Blocks: slices.Clone(blocks), + }}, + } +} + +func AppendPromptText(dst *string, text string) { + text = strings.TrimSpace(text) + if text == "" { + return + } + if *dst == "" { + *dst = text + return + } + *dst = strings.TrimSpace(*dst + "\n\n" + text) +} + +func BuildDataURL(mimeType, b64Data string) string { + return fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data) +} + +func resolveBlockImageURL(block PromptBlock) string { + imageURL := strings.TrimSpace(block.ImageURL) + if imageURL == "" && block.ImageB64 != "" { + mimeType := strings.TrimSpace(block.MimeType) + if mimeType == "" { + mimeType = "image/jpeg" + } + imageURL = BuildDataURL(mimeType, block.ImageB64) + } + return imageURL +} + +func promptContextToResponsesInput(ctx PromptContext) responses.ResponseInputParam { + var result responses.ResponseInputParam + for _, msg := range ctx.Messages { + result = append(result, promptMessageToResponsesInputs(msg)...) + } + return result +} + +func promptMessageToResponsesInputs(msg PromptMessage) responses.ResponseInputParam { + switch msg.Role { + case PromptRoleUser: + content := make([]responses.ResponseInputContentUnionParam, 0, len(msg.Blocks)) + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText: + text := strings.TrimSpace(block.Text) + if text == "" { + continue + } + content = append(content, responses.ResponseInputContentUnionParam{ + OfInputText: &responses.ResponseInputTextParam{Text: text}, + }) + case PromptBlockImage: + imageURL := resolveBlockImageURL(block) + if imageURL == "" { + continue + } + content = append(content, responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: param.NewOpt(imageURL), + }, + }) + } + } + if len(content) == 0 { + return nil + } + return responses.ResponseInputParam{{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfInputItemContentList: content}, + }, + }} + case PromptRoleAssistant: + var result responses.ResponseInputParam + text := strings.TrimSpace(msg.VisibleText()) + if text != "" { + result = append(result, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.String(text)}, + }, + }) + } + for _, block := range msg.Blocks { + if block.Type != PromptBlockToolCall || strings.TrimSpace(block.ToolCallID) == "" || strings.TrimSpace(block.ToolName) == "" { + continue + } + args := strings.TrimSpace(block.ToolCallArguments) + if args == "" { + args = "{}" + } + result = append(result, responses.ResponseInputItemParamOfFunctionCall(args, block.ToolCallID, block.ToolName)) + } + return result + case PromptRoleToolResult: + text := strings.TrimSpace(msg.Text()) + if strings.TrimSpace(msg.ToolCallID) == "" || text == "" { + return nil + } + return responses.ResponseInputParam{buildFunctionCallOutputItem(msg.ToolCallID, text, false)} + default: + return nil + } +} + +func promptContextToChatCompletionMessages(ctx PromptContext, supportsVideoURL bool) []openai.ChatCompletionMessageParamUnion { + var messages []openai.ChatCompletionMessageParamUnion + if system := strings.TrimSpace(ctx.SystemPrompt); system != "" { + messages = append(messages, openai.SystemMessage(system)) + } + for _, msg := range ctx.Messages { + switch msg.Role { + case PromptRoleUser: + user := promptUserToChatMessage(msg) + if user != nil { + messages = append(messages, openai.ChatCompletionMessageParamUnion{OfUser: user}) + } + case PromptRoleAssistant: + assistant := promptAssistantToChatMessage(msg) + if assistant != nil { + messages = append(messages, openai.ChatCompletionMessageParamUnion{OfAssistant: assistant}) + } + case PromptRoleToolResult: + tool := promptToolToChatMessage(msg) + if tool != nil { + messages = append(messages, openai.ChatCompletionMessageParamUnion{OfTool: tool}) + } + } + } + return messages +} + +func promptUserToChatMessage(msg PromptMessage) *openai.ChatCompletionUserMessageParam { + var contentParts []openai.ChatCompletionContentPartUnionParam + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText: + text := strings.TrimSpace(block.Text) + if text == "" { + continue + } + contentParts = append(contentParts, openai.ChatCompletionContentPartUnionParam{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: text, + }, + }) + case PromptBlockImage: + imageURL := resolveBlockImageURL(block) + if imageURL == "" { + continue + } + contentParts = append(contentParts, openai.ChatCompletionContentPartUnionParam{ + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: imageURL, + }, + }, + }) + } + } + if len(contentParts) == 0 { + return nil + } + return &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{OfArrayOfContentParts: contentParts}, + } +} + +func promptAssistantToChatMessage(msg PromptMessage) *openai.ChatCompletionAssistantMessageParam { + var contentParts []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion + var toolCalls []openai.ChatCompletionMessageToolCallUnionParam + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText: + text := strings.TrimSpace(block.Text) + if text == "" { + continue + } + contentParts = append(contentParts, openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: text, + }, + }) + case PromptBlockToolCall: + if strings.TrimSpace(block.ToolCallID) == "" || strings.TrimSpace(block.ToolName) == "" { + continue + } + args := strings.TrimSpace(block.ToolCallArguments) + if args == "" { + args = "{}" + } + toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallUnionParam{ + OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ + ID: block.ToolCallID, + Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ + Name: block.ToolName, + Arguments: args, + }, + }, + }) + } + } + if len(contentParts) == 0 && len(toolCalls) == 0 { + return nil + } + return &openai.ChatCompletionAssistantMessageParam{ + Content: openai.ChatCompletionAssistantMessageParamContentUnion{OfArrayOfContentParts: contentParts}, + ToolCalls: toolCalls, + } +} + +func promptToolToChatMessage(msg PromptMessage) *openai.ChatCompletionToolMessageParam { + text := strings.TrimSpace(msg.Text()) + if strings.TrimSpace(msg.ToolCallID) == "" || text == "" { + return nil + } + return &openai.ChatCompletionToolMessageParam{ + ToolCallID: msg.ToolCallID, + Content: openai.ChatCompletionToolMessageParamContentUnion{ + OfString: openai.String(text), + }, + } +} + +func chatMessagesToPromptContext(messages []openai.ChatCompletionMessageParamUnion) PromptContext { + var ctx PromptContext + for _, msg := range messages { + appendChatMessageToPromptContext(&ctx, msg) + } + return ctx +} + +func appendChatMessageToPromptContext(ctx *PromptContext, msg openai.ChatCompletionMessageParamUnion) { + if ctx == nil { + return + } + switch { + case msg.OfSystem != nil: + AppendPromptText(&ctx.SystemPrompt, extractChatSystemText(msg.OfSystem.Content)) + case msg.OfUser != nil: + ctx.Messages = append(ctx.Messages, promptMessageFromChatUser(msg.OfUser)) + case msg.OfAssistant != nil: + ctx.Messages = append(ctx.Messages, promptMessageFromChatAssistant(msg.OfAssistant)) + case msg.OfTool != nil: + ctx.Messages = append(ctx.Messages, promptMessageFromChatTool(msg.OfTool)) + } +} + +func extractChatSystemText(content openai.ChatCompletionSystemMessageParamContentUnion) string { + if content.OfString.Value != "" { + return content.OfString.Value + } + var values []string + for _, part := range content.OfArrayOfContentParts { + if text := strings.TrimSpace(part.Text); text != "" { + values = append(values, text) + } + } + return strings.Join(values, "\n") +} + +func promptMessageFromChatUser(msg *openai.ChatCompletionUserMessageParam) PromptMessage { + pm := PromptMessage{Role: PromptRoleUser} + if msg == nil { + return pm + } + if msg.Content.OfString.Value != "" { + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: msg.Content.OfString.Value}) + } + for _, part := range msg.Content.OfArrayOfContentParts { + switch { + case part.OfText != nil: + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: part.OfText.Text}) + case part.OfImageURL != nil: + pm.Blocks = append(pm.Blocks, PromptBlock{ + Type: PromptBlockImage, + ImageURL: part.OfImageURL.ImageURL.URL, + MimeType: inferPromptMimeTypeFromDataURL(part.OfImageURL.ImageURL.URL), + }) + } + } + return pm +} + +func promptMessageFromChatAssistant(msg *openai.ChatCompletionAssistantMessageParam) PromptMessage { + pm := PromptMessage{Role: PromptRoleAssistant} + if msg == nil { + return pm + } + if msg.Content.OfString.Value != "" { + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: msg.Content.OfString.Value}) + } + for _, part := range msg.Content.OfArrayOfContentParts { + if part.OfText == nil { + continue + } + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: part.OfText.Text}) + } + for _, toolCall := range msg.ToolCalls { + if toolCall.OfFunction == nil { + continue + } + pm.Blocks = append(pm.Blocks, PromptBlock{ + Type: PromptBlockToolCall, + ToolCallID: toolCall.OfFunction.ID, + ToolName: toolCall.OfFunction.Function.Name, + ToolCallArguments: toolCall.OfFunction.Function.Arguments, + }) + } + return pm +} + +func promptMessageFromChatTool(msg *openai.ChatCompletionToolMessageParam) PromptMessage { + pm := PromptMessage{Role: PromptRoleToolResult} + if msg == nil { + return pm + } + pm.ToolCallID = msg.ToolCallID + if msg.Content.OfString.Value != "" { + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: msg.Content.OfString.Value}) + } + for _, part := range msg.Content.OfArrayOfContentParts { + if strings.TrimSpace(part.Text) == "" { + continue + } + pm.Blocks = append(pm.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) + } + return pm +} + +func inferPromptMimeTypeFromDataURL(value string) string { + value = strings.TrimSpace(value) + rest, ok := strings.CutPrefix(value, "data:") + if !ok { + return "" + } + idx := strings.Index(rest, ";") + if idx <= 0 { + return "" + } + return rest[:idx] +} + +func hasUnsupportedResponsesPromptContext(ctx PromptContext) bool { + for _, msg := range ctx.Messages { + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText, PromptBlockImage, PromptBlockThinking, PromptBlockToolCall: + default: + return true + } + } + } + return false +} diff --git a/bridges/ai/prompt_context_local_test.go b/bridges/ai/prompt_context_local_test.go new file mode 100644 index 00000000..0d6814ea --- /dev/null +++ b/bridges/ai/prompt_context_local_test.go @@ -0,0 +1,23 @@ +package ai + +import "testing" + +func TestPromptAssistantToChatMessageNormalizesBlankToolArguments(t *testing.T) { + msg := PromptMessage{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{{ + Type: PromptBlockToolCall, + ToolCallID: "call_123", + ToolName: "search", + ToolCallArguments: " ", + }}, + } + + assistant := promptAssistantToChatMessage(msg) + if assistant == nil || len(assistant.ToolCalls) != 1 || assistant.ToolCalls[0].OfFunction == nil { + t.Fatalf("expected one function tool call, got %#v", assistant) + } + if got := assistant.ToolCalls[0].OfFunction.Function.Arguments; got != "{}" { + t.Fatalf("expected blank tool arguments to normalize to {}, got %q", got) + } +} diff --git a/bridges/ai/prompt_context_ops.go b/bridges/ai/prompt_context_ops.go new file mode 100644 index 00000000..2a6a1933 --- /dev/null +++ b/bridges/ai/prompt_context_ops.go @@ -0,0 +1,45 @@ +package ai + +import "strings" + +func ClonePromptMessages(messages []PromptMessage) []PromptMessage { + if len(messages) == 0 { + return nil + } + out := make([]PromptMessage, 0, len(messages)) + for _, message := range messages { + cloned := message + if len(message.Blocks) > 0 { + cloned.Blocks = append([]PromptBlock(nil), message.Blocks...) + } + out = append(out, cloned) + } + return out +} + +func ClonePromptContext(ctx PromptContext) PromptContext { + cloned := ctx + cloned.Messages = ClonePromptMessages(ctx.Messages) + if len(ctx.Tools) > 0 { + cloned.Tools = append([]ToolDefinition(nil), ctx.Tools...) + } + return cloned +} + +func PromptContextMessageCount(ctx PromptContext) int { + count := len(ctx.Messages) + if strings.TrimSpace(ctx.SystemPrompt) != "" { + count++ + } + return count +} + +func newUserTextPromptMessage(text string) PromptMessage { + return PromptMessage{ + Role: PromptRoleUser, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: text, + }}, + } +} diff --git a/bridges/ai/prompt_history_test.go b/bridges/ai/prompt_history_test.go new file mode 100644 index 00000000..372c591a --- /dev/null +++ b/bridges/ai/prompt_history_test.go @@ -0,0 +1,32 @@ +package ai + +import "testing" + +func TestFilterPromptBlocksForHistoryDropsThinking(t *testing.T) { + filtered := filterPromptBlocksForHistory([]PromptBlock{ + {Type: PromptBlockThinking, Text: "internal analysis"}, + {Type: PromptBlockText, Text: "visible reply"}, + }, false) + if len(filtered) != 1 { + t.Fatalf("expected 1 block after filtering, got %d", len(filtered)) + } + if filtered[0].Type != PromptBlockText || filtered[0].Text != "visible reply" { + t.Fatalf("unexpected filtered blocks: %#v", filtered) + } +} + +func TestPromptMessageVisibleTextOmitsThinking(t *testing.T) { + msg := PromptMessage{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{ + {Type: PromptBlockThinking, Text: "internal analysis"}, + {Type: PromptBlockText, Text: "visible reply"}, + }, + } + if got := msg.VisibleText(); got != "visible reply" { + t.Fatalf("expected visible text only, got %q", got) + } + if got := msg.Text(); got != "internal analysis\nvisible reply" { + t.Fatalf("expected full text to retain thinking, got %q", got) + } +} diff --git a/bridges/ai/prompt_projection_local.go b/bridges/ai/prompt_projection_local.go new file mode 100644 index 00000000..8dee4c39 --- /dev/null +++ b/bridges/ai/prompt_projection_local.go @@ -0,0 +1,196 @@ +package ai + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/beeper/agentremote/sdk" +) + +func promptMessagesFromTurnData(td sdk.TurnData) []PromptMessage { + if td.Role == "" { + return nil + } + switch td.Role { + case "user": + msg := PromptMessage{Role: PromptRoleUser} + for _, part := range td.Parts { + switch normalizePromptTurnPartType(part.Type) { + case "text": + if strings.TrimSpace(part.Text) != "" { + msg.Blocks = append(msg.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) + } + case "image": + imageB64 := promptExtraString(part.Extra, "imageB64") + if strings.TrimSpace(part.URL) == "" && imageB64 == "" { + continue + } + msg.Blocks = append(msg.Blocks, PromptBlock{ + Type: PromptBlockImage, + ImageURL: part.URL, + ImageB64: imageB64, + MimeType: part.MediaType, + }) + } + } + if len(msg.Blocks) == 0 { + return nil + } + return []PromptMessage{msg} + case "assistant": + assistant := PromptMessage{Role: PromptRoleAssistant} + var results []PromptMessage + for _, part := range td.Parts { + switch normalizePromptTurnPartType(part.Type) { + case "text": + if strings.TrimSpace(part.Text) != "" { + assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) + } + case "reasoning": + text := strings.TrimSpace(part.Reasoning) + if text == "" { + text = strings.TrimSpace(part.Text) + } + if text != "" { + assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockThinking, Text: text}) + } + case "tool": + if strings.TrimSpace(part.ToolCallID) != "" && strings.TrimSpace(part.ToolName) != "" { + assistant.Blocks = append(assistant.Blocks, PromptBlock{ + Type: PromptBlockToolCall, + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + ToolCallArguments: canonicalPromptToolArguments(part.Input), + }) + } + outputText := strings.TrimSpace(formatPromptCanonicalValue(part.Output)) + if outputText == "" { + outputText = strings.TrimSpace(part.ErrorText) + } + if outputText == "" && part.State == "output-denied" { + outputText = "Denied by user" + } + if strings.TrimSpace(part.ToolCallID) != "" && outputText != "" { + results = append(results, PromptMessage{ + Role: PromptRoleToolResult, + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + IsError: strings.TrimSpace(part.ErrorText) != "", + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: outputText, + }}, + }) + } + } + } + if len(assistant.Blocks) == 0 && len(results) == 0 { + return nil + } + out := make([]PromptMessage, 0, 1+len(results)) + if len(assistant.Blocks) > 0 { + out = append(out, assistant) + } + return append(out, results...) + default: + return nil + } +} + +// turnDataFromUserPromptMessages intentionally projects only the latest user +// message because callers pass a single-message tail via promptTail(..., 1). +func turnDataFromUserPromptMessages(messages []PromptMessage) (sdk.TurnData, bool) { + if len(messages) == 0 { + return sdk.TurnData{}, false + } + msg := messages[0] + if msg.Role != PromptRoleUser { + return sdk.TurnData{}, false + } + td := sdk.TurnData{Role: "user"} + td.Parts = make([]sdk.TurnPart, 0, len(msg.Blocks)) + for _, block := range msg.Blocks { + switch block.Type { + case PromptBlockText: + if strings.TrimSpace(block.Text) != "" { + td.Parts = append(td.Parts, sdk.TurnPart{Type: "text", Text: block.Text}) + } + case PromptBlockImage: + if strings.TrimSpace(block.ImageURL) == "" && strings.TrimSpace(block.ImageB64) == "" { + continue + } + part := sdk.TurnPart{Type: "image", URL: block.ImageURL, MediaType: block.MimeType} + if strings.TrimSpace(block.ImageB64) != "" { + part.Extra = map[string]any{"imageB64": block.ImageB64} + } + td.Parts = append(td.Parts, part) + } + } + return td, len(td.Parts) > 0 +} + +func promptExtraString(extra map[string]any, key string) string { + if len(extra) == 0 { + return "" + } + value, _ := extra[key].(string) + return value +} + +func normalizePromptTurnPartType(partType string) string { + if partType == "dynamic-tool" { + return "tool" + } + return partType +} + +func canonicalPromptToolArguments(raw any) string { + switch typed := raw.(type) { + case nil: + return "{}" + case string: + trimmed := strings.TrimSpace(typed) + if trimmed == "" { + return "{}" + } + var decoded any + if err := json.Unmarshal([]byte(trimmed), &decoded); err == nil { + data, marshalErr := json.Marshal(decoded) + if marshalErr == nil && string(data) != "null" { + return string(data) + } + } + data, err := json.Marshal(typed) + if err == nil && string(data) != "null" { + return string(data) + } + default: + if data, err := json.Marshal(typed); err == nil && string(data) != "null" { + return string(data) + } + } + if value := strings.TrimSpace(formatPromptCanonicalValue(raw)); value != "" { + data, err := json.Marshal(value) + if err == nil && string(data) != "null" { + return string(data) + } + return value + } + return "{}" +} + +func formatPromptCanonicalValue(raw any) string { + switch typed := raw.(type) { + case nil: + return "" + case string: + return typed + default: + data, err := json.Marshal(typed) + if err != nil { + return fmt.Sprint(typed) + } + return string(data) + } +} diff --git a/bridges/ai/prompt_projection_local_test.go b/bridges/ai/prompt_projection_local_test.go new file mode 100644 index 00000000..c7676e14 --- /dev/null +++ b/bridges/ai/prompt_projection_local_test.go @@ -0,0 +1,15 @@ +package ai + +import "testing" + +func TestCanonicalPromptToolArgumentsJSONEncodesPlainStrings(t *testing.T) { + if got := canonicalPromptToolArguments("hello"); got != `"hello"` { + t.Fatalf("expected plain string to be JSON-encoded, got %q", got) + } +} + +func TestCanonicalPromptToolArgumentsPreservesJSONStrings(t *testing.T) { + if got := canonicalPromptToolArguments(`{"query":"matrix"}`); got != `{"query":"matrix"}` { + t.Fatalf("expected JSON string to stay canonical JSON, got %q", got) + } +} diff --git a/bridges/ai/provider_openai_chat.go b/bridges/ai/provider_openai_chat.go deleted file mode 100644 index f1e7be01..00000000 --- a/bridges/ai/provider_openai_chat.go +++ /dev/null @@ -1,56 +0,0 @@ -package ai - -import ( - "context" - "errors" - "fmt" - - "github.com/openai/openai-go/v3" - - bridgesdk "github.com/beeper/agentremote/sdk" -) - -func (o *OpenAIProvider) generateChatCompletions(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - chatMessages := bridgesdk.PromptContextToChatCompletionMessages(params.Context.PromptContext, isOpenRouterBaseURL(o.baseURL)) - if len(chatMessages) == 0 { - return nil, errors.New("no chat messages for completion") - } - - req := openai.ChatCompletionNewParams{ - Model: params.Model, - Messages: chatMessages, - } - if params.MaxCompletionTokens > 0 { - req.MaxCompletionTokens = openai.Int(int64(params.MaxCompletionTokens)) - } - if params.Temperature != nil { - req.Temperature = openai.Float(*params.Temperature) - } - if len(params.Context.Tools) > 0 { - req.Tools = ToOpenAIChatTools(params.Context.Tools, resolveToolStrictMode(isOpenRouterBaseURL(o.baseURL)), &o.log) - req.Tools = dedupeChatToolParams(req.Tools) - } - - resp, err := o.client.Chat.Completions.New(ctx, req) - if err != nil { - return nil, fmt.Errorf("OpenAI chat completion failed: %w", err) - } - - var content string - var finishReason string - if len(resp.Choices) > 0 { - content = resp.Choices[0].Message.Content - finishReason = resp.Choices[0].FinishReason - } - - return &GenerateResponse{ - Content: content, - FinishReason: finishReason, - Usage: UsageInfo{ - PromptTokens: int(resp.Usage.PromptTokens), - CompletionTokens: int(resp.Usage.CompletionTokens), - TotalTokens: int(resp.Usage.TotalTokens), - ReasoningTokens: int(resp.Usage.CompletionTokensDetails.ReasoningTokens), - }, - }, nil -} diff --git a/bridges/ai/provider_openai_responses.go b/bridges/ai/provider_openai_responses.go index 1e18e08e..c0f83dde 100644 --- a/bridges/ai/provider_openai_responses.go +++ b/bridges/ai/provider_openai_responses.go @@ -8,8 +8,6 @@ import ( "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" - - bridgesdk "github.com/beeper/agentremote/sdk" ) // reasoningEffortMap maps string effort levels to SDK constants. @@ -24,7 +22,7 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R responsesParams := responses.ResponseNewParams{ Model: params.Model, Input: responses.ResponseNewParamsInputUnion{ - OfInputItemList: bridgesdk.PromptContextToResponsesInput(params.Context.PromptContext), + OfInputItemList: promptContextToResponsesInput(params.Context), }, } @@ -60,7 +58,7 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R // GenerateStream generates a streaming response from OpenAI using the Responses API. func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GenerateParams) (<-chan StreamEvent, error) { - if bridgesdk.HasUnsupportedResponsesPromptContext(params.Context.PromptContext) { + if hasUnsupportedResponsesPromptContext(params.Context) { return nil, fmt.Errorf("responses API does not support prompt context block types required by this request") } @@ -150,7 +148,7 @@ func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GeneratePara // Generate performs a non-streaming generation using the Responses API. func (o *OpenAIProvider) Generate(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - if bridgesdk.HasUnsupportedResponsesPromptContext(params.Context.PromptContext) { + if hasUnsupportedResponsesPromptContext(params.Context) { return nil, fmt.Errorf("responses API does not support prompt context block types required by this request") } diff --git a/bridges/ai/provider_openai_responses_test.go b/bridges/ai/provider_openai_responses_test.go index 70e02059..9db3e9c7 100644 --- a/bridges/ai/provider_openai_responses_test.go +++ b/bridges/ai/provider_openai_responses_test.go @@ -1,37 +1,19 @@ package ai import ( - "context" - "strings" "testing" "go.mau.fi/util/ptr" - - bridgesdk "github.com/beeper/agentremote/sdk" ) -func TestGenerateStreamRejectsUnsupportedResponsesPromptContext(t *testing.T) { +func TestBuildResponsesParamsAcceptsBridgePromptContext(t *testing.T) { provider := &OpenAIProvider{} - params := GenerateParams{ - Context: PromptContext{ - PromptContext: bridgesdk.UserPromptContext(bridgesdk.PromptBlock{ - Type: bridgesdk.PromptBlockAudio, - AudioB64: "YXVkaW8=", - AudioFormat: "mp3", - MimeType: "audio/mpeg", - }), - }, - } - - events, err := provider.GenerateStream(context.Background(), params) - if err == nil { - t.Fatal("expected unsupported prompt context error") - } - if events != nil { - t.Fatal("expected nil event channel on validation failure") - } - if !strings.Contains(err.Error(), "responses API does not support prompt context block types required by this request") { - t.Fatalf("unexpected error: %v", err) + params := provider.buildResponsesParams(GenerateParams{ + Model: "gpt-5.2", + Context: UserPromptContext(PromptBlock{Type: PromptBlockText, Text: "hello"}), + }) + if len(params.Input.OfInputItemList) != 1 { + t.Fatalf("expected one input item, got %d", len(params.Input.OfInputItemList)) } } diff --git a/bridges/ai/provisioning.go b/bridges/ai/provisioning.go index 6d3e5caf..9d458d0e 100644 --- a/bridges/ai/provisioning.go +++ b/bridges/ai/provisioning.go @@ -541,11 +541,15 @@ func resolveNamedMCPServer(client *AIClient, name string) (namedMCPServer, error } func ensureLoginMCPServer(meta *UserLoginMetadata) { - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} + creds := ensureLoginCredentials(meta) + if creds == nil { + return + } + if creds.ServiceTokens == nil { + creds.ServiceTokens = &ServiceTokens{} } - if meta.ServiceTokens.MCPServers == nil { - meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} + if creds.ServiceTokens.MCPServers == nil { + creds.ServiceTokens.MCPServers = map[string]MCPServerConfig{} } } @@ -584,7 +588,12 @@ func (api *ProvisioningAPI) handleCreateMCPServer(w http.ResponseWriter, r *http } meta := loginMetadata(login) ensureLoginMCPServer(meta) - if _, exists := meta.ServiceTokens.MCPServers[name]; exists { + tokens := loginCredentialServiceTokens(meta) + if tokens == nil { + mautrix.MUnknown.WithMessage("Couldn't load MCP servers for this login.").Write(w) + return + } + if _, exists := tokens.MCPServers[name]; exists { mautrix.MInvalidParam.WithMessage("MCP server %s already exists.", name).Write(w) return } diff --git a/bridges/ai/queue_status_test.go b/bridges/ai/queue_status_test.go index 8be93e96..b30b6092 100644 --- a/bridges/ai/queue_status_test.go +++ b/bridges/ai/queue_status_test.go @@ -140,3 +140,53 @@ func TestDispatchOrQueueQueueAcceptReturnsPending(t *testing.T) { t.Fatalf("expected queue length 1 after accept, got %d", got) } } + +func TestDispatchOrQueueQueuesBehindExistingPendingWork(t *testing.T) { + roomID := id.RoomID("!room:example.com") + oc := &AIClient{ + activeRooms: map[id.RoomID]bool{}, + pendingQueues: map[id.RoomID]*pendingQueue{}, + } + oc.pendingQueues[roomID] = &pendingQueue{ + items: []pendingQueueItem{ + { + pending: pendingMessage{Type: pendingTypeText, MessageBody: "older"}, + }, + }, + cap: 10, + dropPolicy: airuntime.QueueDropOld, + } + + evt := &event.Event{ID: id.EventID("$new")} + portal := &bridgev2.Portal{Portal: &database.Portal{}} + portal.MXID = roomID + queueItem := pendingQueueItem{ + pending: pendingMessage{Type: pendingTypeText, MessageBody: "new"}, + messageID: string(evt.ID), + } + + _, isPending := oc.dispatchOrQueue( + context.Background(), + evt, + portal, + nil, + nil, + queueItem, + airuntime.QueueSettings{Mode: airuntime.QueueModeCollect, Cap: 10, DropPolicy: airuntime.QueueDropOld}, + PromptContext{}, + ) + + if !isPending { + t.Fatalf("expected pending=true when older queued work exists") + } + queue := oc.pendingQueues[roomID] + if queue == nil { + t.Fatalf("expected pending queue to exist") + } + if got := len(queue.items); got != 2 { + t.Fatalf("expected queue length 2 after enqueue behind backlog, got %d", got) + } + if oc.activeRooms[roomID] { + t.Fatalf("expected room to remain unacquired while backlog exists") + } +} diff --git a/bridges/ai/response_finalization_test.go b/bridges/ai/response_finalization_test.go index 83612a79..5f61bdaa 100644 --- a/bridges/ai/response_finalization_test.go +++ b/bridges/ai/response_finalization_test.go @@ -14,7 +14,7 @@ import ( ) func testStreamingState(turnID string) *streamingState { - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) turn := conv.StartTurn(context.Background(), nil, nil) turn.SetID(turnID) return &streamingState{ diff --git a/bridges/ai/response_retry.go b/bridges/ai/response_retry.go index 5731318f..a999ace5 100644 --- a/bridges/ai/response_retry.go +++ b/bridges/ai/response_retry.go @@ -5,24 +5,21 @@ import ( "errors" "fmt" "math" - "slices" + "strings" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" airuntime "github.com/beeper/agentremote/pkg/runtime" - bridgesdk "github.com/beeper/agentremote/sdk" ) const ( maxRetryAttempts = 3 // Maximum retry attempts for context length errors ) -// responseFunc is the signature for response handlers that can be retried on context length errors -type responseFunc func(ctx context.Context, evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion) (bool, *ContextLengthError, error) +type responseFuncCanonical func(ctx context.Context, evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, prompt PromptContext) (bool, *ContextLengthError, error) // responseWithRetry wraps a response function with context length retry logic. // It performs one runtime compaction retry attempt. @@ -31,19 +28,20 @@ func (oc *AIClient) responseWithRetry( evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, - responseFn responseFunc, + prompt PromptContext, + responseFn responseFuncCanonical, logLabel string, ) (bool, error) { - currentPrompt := prompt + currentPrompt := ClonePromptContext(prompt) preflightFlushAttempted := false overflowCompactionAttempts := 0 + cachedTokenEstimate := -1 var lastCLE *ContextLengthError for attempt := range maxRetryAttempts { if !preflightFlushAttempted { preflightFlushAttempted = true - oc.runCompactionPreflightFlushHook(ctx, portal, meta, currentPrompt, attempt+1) + cachedTokenEstimate = oc.runCompactionPreflightFlushHook(ctx, portal, meta, currentPrompt, attempt+1) } success, cle, err := responseFn(ctx, evt, portal, meta, currentPrompt) @@ -72,7 +70,11 @@ func (oc *AIClient) responseWithRetry( if meta != nil { modelID = oc.effectiveModel(meta) } - tokensBefore := estimatePromptTokensForModel(currentPrompt, modelID) + tokensBefore := cachedTokenEstimate + if tokensBefore < 0 { + tokensBefore = estimatePromptContextTokensForModel(currentPrompt, modelID) + } + cachedTokenEstimate = -1 // invalidate after use if overflowCompactionAttempts < maxRetryAttempts { overflowCompactionAttempts++ @@ -80,7 +82,6 @@ func (oc *AIClient) responseWithRetry( // Emit compaction start event. oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecycleStart, @@ -88,19 +89,19 @@ func (oc *AIClient) responseWithRetry( ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensBefore, - MessagesBefore: len(currentPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), TokensBefore: tokensBefore, }) oc.emitCompactionStatus(ctx, portal, &CompactionEvent{ Type: CompactionEventStart, SessionID: sessionID, - MessagesBefore: len(currentPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), }) compacted, decision, compactionSuccess := oc.runtimeCompactOnOverflow(currentPrompt, contextWindow, cle.RequestedTokens, tokensBefore) - if compactionSuccess && len(compacted) > 2 { + if compactionSuccess && PromptContextMessageCount(compacted) > 2 { compacted = oc.applyCompactionModelSummaryAndRefresh(ctx, meta, currentPrompt, compacted, decision, contextWindow) - tokensAfter := estimatePromptTokensForModel(compacted, modelID) + tokensAfter := estimatePromptContextTokensForModel(compacted, modelID) if meta != nil { meta.CompactionCount++ oc.savePortalQuiet(ctx, portal, "compaction count") @@ -112,23 +113,22 @@ func (oc *AIClient) responseWithRetry( oc.emitCompactionStatus(ctx, portal, &CompactionEvent{ Type: CompactionEventEnd, SessionID: sessionID, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(compacted), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(compacted), TokensBefore: tokensBefore, TokensAfter: tokensAfter, Summary: summary, WillRetry: true, }) oc.emitCompactionLifecyclePhases(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Attempt: attempt + 1, ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensAfter, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(compacted), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(compacted), TokensBefore: tokensBefore, TokensAfter: tokensAfter, DroppedCount: decision.DroppedCount, @@ -137,8 +137,8 @@ func (oc *AIClient) responseWithRetry( }, integrationruntime.CompactionLifecycleEnd, integrationruntime.CompactionLifecycleRefresh) oc.loggerForContext(ctx).Info(). - Int("messages_before", len(currentPrompt)). - Int("messages_after", len(compacted)). + Int("messages_before", PromptContextMessageCount(currentPrompt)). + Int("messages_after", PromptContextMessageCount(compacted)). Int("tokens_before", tokensBefore). Int("tokens_after", tokensAfter). Int("dropped", decision.DroppedCount). @@ -150,19 +150,18 @@ func (oc *AIClient) responseWithRetry( // Compaction was insufficient. Try an explicit tool-result truncation pass. truncatedPrompt, truncatedCount := oc.truncateOversizedToolResultsForOverflow(currentPrompt, contextWindow) if truncatedCount > 0 { - tokensAfter := estimatePromptTokensForModel(truncatedPrompt, modelID) + tokensAfter := estimatePromptContextTokensForModel(truncatedPrompt, modelID) oc.emitCompactionStatus(ctx, portal, &CompactionEvent{ Type: CompactionEventEnd, SessionID: sessionID, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(truncatedPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(truncatedPrompt), TokensBefore: tokensBefore, TokensAfter: tokensAfter, Summary: fmt.Sprintf("Truncated %d oversized tool result(s).", truncatedCount), WillRetry: true, }) oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecycleEnd, @@ -170,8 +169,8 @@ func (oc *AIClient) responseWithRetry( ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensAfter, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(truncatedPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), + MessagesAfter: PromptContextMessageCount(truncatedPrompt), TokensBefore: tokensBefore, TokensAfter: tokensAfter, Reason: "truncate_oversized_tool_results", @@ -192,7 +191,6 @@ func (oc *AIClient) responseWithRetry( Error: "compaction did not reduce context sufficiently", }) oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecycleFail, @@ -200,7 +198,7 @@ func (oc *AIClient) responseWithRetry( ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, PromptTokens: tokensBefore, - MessagesBefore: len(currentPrompt), + MessagesBefore: PromptContextMessageCount(currentPrompt), TokensBefore: tokensBefore, Reason: "compaction did not reduce context sufficiently", Error: "compaction did not reduce context sufficiently", @@ -236,21 +234,20 @@ func (oc *AIClient) runCompactionPreflightFlushHook( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, attempt int, -) { +) int { if oc == nil || meta == nil { - return + return -1 } contextWindow := oc.getModelContextWindow(meta) if contextWindow <= 0 { contextWindow = 128000 } modelID := oc.effectiveModel(meta) - promptTokens := estimatePromptTokensForModel(prompt, modelID) + promptTokens := estimatePromptContextTokensForModel(prompt, modelID) projectedTokens := projectedCompactionFlushTokens(meta, promptTokens) oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, Portal: portal, Meta: meta, Phase: integrationruntime.CompactionLifecyclePreFlush, @@ -258,13 +255,14 @@ func (oc *AIClient) runCompactionPreflightFlushHook( ContextWindowTokens: contextWindow, RequestedTokens: projectedTokens, PromptTokens: promptTokens, - MessagesBefore: len(prompt), + MessagesBefore: PromptContextMessageCount(prompt), TokensBefore: promptTokens, }) oc.runCompactionFlushHook(ctx, portal, meta, prompt, &ContextLengthError{ RequestedTokens: projectedTokens, ModelMaxTokens: contextWindow, }, attempt) + return promptTokens } func projectedCompactionFlushTokens(meta *PortalMetadata, promptTokens int) int { @@ -314,7 +312,7 @@ func (oc *AIClient) runCompactionFlushHook( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, cle *ContextLengthError, attempt int, ) { @@ -340,10 +338,9 @@ func (oc *AIClient) runCompactionFlushHook( return } hook.OnContextOverflow(ctx, integrationruntime.ContextOverflowCall{ - Client: oc, Portal: portal, Meta: meta, - Prompt: prompt, + Prompt: promptContextToChatCompletionMessages(prompt, false), RequestedTokens: cle.RequestedTokens, ModelMaxTokens: cle.ModelMaxTokens, Attempt: attempt, @@ -357,9 +354,8 @@ func (oc *AIClient) runAgentLoopWithRetry( meta *PortalMetadata, promptContext PromptContext, ) { - prompt := oc.promptContextToDispatchMessages(ctx, portal, meta, promptContext) responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, promptContext) - success, err := oc.responseWithRetry(ctx, evt, portal, meta, prompt, responseFn, logLabel) + success, err := oc.responseWithRetry(ctx, evt, portal, meta, promptContext, responseFn, logLabel) if success || err == nil { return } @@ -369,9 +365,9 @@ func (oc *AIClient) runAgentLoopWithRetry( oc.notifyMatrixSendFailure(ctx, portal, evt, err) } -func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext PromptContext) (responseFunc, string) { - if bridgesdk.HasUnsupportedResponsesPromptContext(promptContext.PromptContext) { - return oc.runChatCompletionsAgentLoop, "chat_completions" +func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext PromptContext) (responseFuncCanonical, string) { + if hasUnsupportedResponsesPromptContext(promptContext) { + return oc.runChatCompletionsAgentLoopPrompt, "chat_completions" } modelID := "" if oc != nil { @@ -380,13 +376,13 @@ func (oc *AIClient) selectAgentLoopRunFunc(meta *PortalMetadata, promptContext P switch oc.resolveModelAPI(meta) { case ModelAPIChatCompletions: if isDirectOpenAIModel(modelID) { - return func(context.Context, *event.Event, *bridgev2.Portal, *PortalMetadata, []openai.ChatCompletionMessageParamUnion) (bool, *ContextLengthError, error) { + return func(context.Context, *event.Event, *bridgev2.Portal, *PortalMetadata, PromptContext) (bool, *ContextLengthError, error) { return false, nil, fmt.Errorf("invalid model configuration: direct OpenAI model %q cannot use chat_completions", modelID) }, "invalid_model_api" } - return oc.runChatCompletionsAgentLoop, "chat_completions" + return oc.runChatCompletionsAgentLoopPrompt, "chat_completions" default: - return oc.runResponsesAgentLoop, "responses" + return oc.runResponsesAgentLoopPrompt, "responses" } } @@ -416,13 +412,14 @@ func (oc *AIClient) notifyContextLengthExceeded( } func (oc *AIClient) runtimeCompactOnOverflow( - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, contextWindowTokens int, requestedTokens int, currentPromptTokens int, -) ([]openai.ChatCompletionMessageParamUnion, airuntime.CompactionDecision, bool) { +) (PromptContext, airuntime.CompactionDecision, bool) { + serialized := promptContextToChatCompletionMessages(prompt, false) result := airuntime.CompactPromptOnOverflow(airuntime.OverflowCompactionInput{ - Prompt: prompt, + Prompt: serialized, ContextWindowTokens: contextWindowTokens, RequestedTokens: requestedTokens, CurrentPromptTokens: currentPromptTokens, @@ -435,14 +432,14 @@ func (oc *AIClient) runtimeCompactOnOverflow( MaxHistoryShare: oc.pruningMaxHistoryShare(), ProtectedTail: 3, }) - return result.Prompt, result.Decision, result.Success + return chatMessagesToPromptContext(result.Prompt), result.Decision, result.Success } func (oc *AIClient) truncateOversizedToolResultsForOverflow( - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, contextWindowTokens int, -) ([]openai.ChatCompletionMessageParamUnion, int) { - if len(prompt) == 0 { +) (PromptContext, int) { + if len(prompt.Messages) == 0 { return prompt, 0 } cfg := oc.pruningConfigOrDefault() @@ -461,13 +458,13 @@ func (oc *AIClient) truncateOversizedToolResultsForOverflow( } } - out := slices.Clone(prompt) + out := ClonePromptContext(prompt) truncated := 0 - for i, msg := range out { - if msg.OfTool == nil { + for i, msg := range out.Messages { + if msg.Role != PromptRoleToolResult { continue } - content := airuntime.ExtractToolContent(msg.OfTool.Content) + content := strings.TrimSpace(msg.Text()) if len(content) <= thresholdChars { continue } @@ -475,12 +472,52 @@ func (oc *AIClient) truncateOversizedToolResultsForOverflow( if trimmed == content { continue } - out[i] = openai.ToolMessage(trimmed, msg.OfTool.ToolCallID) + out.Messages[i].Blocks = rewriteTrimmedToolResultBlocks(msg.Blocks, trimmed) truncated++ } return out, truncated } +func rewriteTrimmedToolResultBlocks(blocks []PromptBlock, trimmed string) []PromptBlock { + if len(blocks) == 0 { + return []PromptBlock{{Type: PromptBlockText, Text: trimmed}} + } + remaining := trimmed + rewritten := make([]PromptBlock, 0, len(blocks)) + previousTextBlock := false + for _, block := range blocks { + if remaining == "" { + break + } + switch block.Type { + case PromptBlockText, PromptBlockThinking: + default: + continue + } + if block.Text == "" { + continue + } + if previousTextBlock && strings.HasPrefix(remaining, "\n") { + remaining = remaining[1:] + if remaining == "" { + break + } + } + take := len(block.Text) + if take > len(remaining) { + take = len(remaining) + } + block.Text = remaining[:take] + remaining = remaining[take:] + rewritten = append(rewritten, block) + previousTextBlock = true + } + if len(rewritten) == 0 { + return []PromptBlock{{Type: PromptBlockText, Text: trimmed}} + } + return rewritten +} + // emitCompactionStatus sends a compaction status event to the room func (oc *AIClient) emitCompactionStatus(ctx context.Context, portal *bridgev2.Portal, evt *CompactionEvent) { if portal == nil || portal.MXID == "" { diff --git a/bridges/ai/response_retry_test.go b/bridges/ai/response_retry_test.go index 080d807b..aa909070 100644 --- a/bridges/ai/response_retry_test.go +++ b/bridges/ai/response_retry_test.go @@ -1,6 +1,7 @@ package ai import ( + "strings" "testing" "github.com/rs/zerolog" @@ -23,7 +24,9 @@ func newPruningTestClient(pruning *airuntime.PruningConfig, provider string) *AI }, connector: &OpenAIConnector{ Config: Config{ - Pruning: pruning, + Agents: &AgentsConfig{Defaults: &AgentDefaultsConfig{ + Compaction: pruning, + }}, }, }, log: zerolog.Nop(), @@ -171,3 +174,37 @@ func TestPruningPostCompactionRefreshPrompt_Defaults(t *testing.T) { t.Fatal("expected non-empty post-compaction refresh prompt") } } + +func TestTruncateOversizedToolResultsForOverflowPreservesBlockStructure(t *testing.T) { + client := newPruningTestClient(&airuntime.PruningConfig{ + SoftTrimMaxChars: 20, + SoftTrimHeadChars: 8, + SoftTrimTailChars: 8, + }, ProviderOpenAI) + + prompt := PromptContext{ + Messages: []PromptMessage{{ + Role: PromptRoleToolResult, + Blocks: []PromptBlock{ + {Type: PromptBlockText, Text: strings.Repeat("first-block-", 12)}, + {Type: PromptBlockText, Text: strings.Repeat("second-block-", 12)}, + }, + }}, + } + + got, truncated := client.truncateOversizedToolResultsForOverflow(prompt, 0) + if truncated != 1 { + t.Fatalf("expected one truncated tool result, got %d", truncated) + } + if len(got.Messages) != 1 || len(got.Messages[0].Blocks) == 0 { + t.Fatalf("expected rewritten blocks, got %#v", got.Messages) + } + for _, block := range got.Messages[0].Blocks { + if block.Type != PromptBlockText { + t.Fatalf("expected text blocks to be preserved, got %#v", got.Messages[0].Blocks) + } + } + if got.Messages[0].Text() != airuntime.SoftTrimToolResult(prompt.Messages[0].Text(), client.pruningConfigOrDefault()) { + t.Fatalf("expected trimmed text to match soft-trim output, got %q", got.Messages[0].Text()) + } +} diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index f13b0a00..64071164 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -117,16 +117,6 @@ func (oc *AIClient) enqueueSteerQueue(roomID id.RoomID, item pendingQueueItem) b return true } -func (oc *AIClient) registerRoomRunPendingItem(roomID id.RoomID, item pendingQueueItem) { - run := oc.getRoomRun(roomID) - if run == nil { - return - } - run.mu.Lock() - defer run.mu.Unlock() - oc.registerRoomRunPendingItemLocked(run, item) -} - func (oc *AIClient) registerRoomRunPendingItemLocked(run *roomRunState, item pendingQueueItem) { if run == nil { return diff --git a/bridges/ai/runtime_compaction_adapter.go b/bridges/ai/runtime_compaction_adapter.go index 2ce1362b..165cb0ac 100644 --- a/bridges/ai/runtime_compaction_adapter.go +++ b/bridges/ai/runtime_compaction_adapter.go @@ -30,8 +30,10 @@ type CompactionEvent struct { } func (oc *AIClient) pruningConfigOrDefault() *airuntime.PruningConfig { - if oc != nil && oc.connector != nil && oc.connector.Config.Pruning != nil { - return airuntime.ApplyPruningDefaults(oc.connector.Config.Pruning) + if oc != nil && oc.connector != nil && oc.connector.Config.Agents != nil && + oc.connector.Config.Agents.Defaults != nil && + oc.connector.Config.Agents.Defaults.Compaction != nil { + return airuntime.ApplyPruningDefaults(oc.connector.Config.Agents.Defaults.Compaction) } return airuntime.DefaultPruningConfig() } @@ -161,3 +163,7 @@ func estimatePromptTokensForModel(prompt []openai.ChatCompletionMessageParamUnio } return estimatePromptTokensFallback(prompt) } + +func estimatePromptContextTokensForModel(prompt PromptContext, model string) int { + return estimatePromptTokensForModel(promptContextToChatCompletionMessages(prompt, false), model) +} diff --git a/bridges/ai/runtime_defaults_test.go b/bridges/ai/runtime_defaults_test.go index a06e32e9..a6acc25e 100644 --- a/bridges/ai/runtime_defaults_test.go +++ b/bridges/ai/runtime_defaults_test.go @@ -18,42 +18,44 @@ func TestApplyRuntimeDefaultsSetsPruningDefaults(t *testing.T) { if connector.Config.Bridge.CommandPrefix != "!ai" { t.Fatalf("expected command prefix !ai, got %q", connector.Config.Bridge.CommandPrefix) } - if connector.Config.Pruning == nil { - t.Fatal("expected pruning defaults to be initialized") + if connector.Config.Agents == nil || connector.Config.Agents.Defaults == nil || connector.Config.Agents.Defaults.Compaction == nil { + t.Fatal("expected compaction defaults to be initialized") } - if !connector.Config.Pruning.Enabled { + if !connector.Config.Agents.Defaults.Compaction.Enabled { t.Fatal("expected pruning defaults enabled") } - if connector.Config.Pruning.Mode != "cache-ttl" { - t.Fatalf("expected pruning mode cache-ttl, got %q", connector.Config.Pruning.Mode) + if connector.Config.Agents.Defaults.Compaction.Mode != "cache-ttl" { + t.Fatalf("expected pruning mode cache-ttl, got %q", connector.Config.Agents.Defaults.Compaction.Mode) } - if connector.Config.Pruning.TTL != time.Hour { - t.Fatalf("expected pruning ttl 1h, got %v", connector.Config.Pruning.TTL) + if connector.Config.Agents.Defaults.Compaction.TTL != time.Hour { + t.Fatalf("expected pruning ttl 1h, got %v", connector.Config.Agents.Defaults.Compaction.TTL) } } func TestApplyRuntimeDefaultsKeepsExplicitPruningModeOff(t *testing.T) { connector := &OpenAIConnector{ Config: Config{ - Pruning: &airuntime.PruningConfig{ - Mode: "off", - Enabled: false, - }, + Agents: &AgentsConfig{Defaults: &AgentDefaultsConfig{ + Compaction: &airuntime.PruningConfig{ + Mode: "off", + Enabled: false, + }, + }}, }, } connector.applyRuntimeDefaults() - if connector.Config.Pruning == nil { + if connector.Config.Agents == nil || connector.Config.Agents.Defaults == nil || connector.Config.Agents.Defaults.Compaction == nil { t.Fatal("expected pruning config to remain set") } - if connector.Config.Pruning.Mode != "off" { - t.Fatalf("expected pruning mode off to be preserved, got %q", connector.Config.Pruning.Mode) + if connector.Config.Agents.Defaults.Compaction.Mode != "off" { + t.Fatalf("expected pruning mode off to be preserved, got %q", connector.Config.Agents.Defaults.Compaction.Mode) } - if connector.Config.Pruning.Enabled { + if connector.Config.Agents.Defaults.Compaction.Enabled { t.Fatal("expected pruning enabled=false to be preserved") } - if connector.Config.Pruning.SoftTrimRatio <= 0 { + if connector.Config.Agents.Defaults.Compaction.SoftTrimRatio <= 0 { t.Fatal("expected missing pruning numeric defaults to be filled") } } diff --git a/bridges/ai/session_greeting.go b/bridges/ai/session_greeting.go index 5594d6df..f393dad9 100644 --- a/bridges/ai/session_greeting.go +++ b/bridges/ai/session_greeting.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" ) @@ -13,25 +12,24 @@ import ( const sessionGreetingPrompt = "A new session was started via !ai reset. Greet the user in your configured persona, if one is provided. Be yourself - use your defined voice, mannerisms, and mood. Keep it to 1-3 sentences and ask what they want to do. If the runtime model differs from default_model in the system prompt, mention the default model. Do not mention internal steps, files, tools, or reasoning." const autoGreetingPrompt = "A new chat was created. Greet the user in your configured persona, if one is provided. Be yourself - use your defined voice, mannerisms, and mood. Keep it to 1-3 sentences and ask what they want to do. If the runtime model differs from default_model in the system prompt, mention the default model. Do not mention internal steps, files, tools, or reasoning." -func maybePrependSessionGreeting( +func sessionGreetingFragment( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, log zerolog.Logger, -) []openai.ChatCompletionMessageParamUnion { +) string { if meta == nil { - return prompt + return "" } agentID := strings.TrimSpace(resolveAgentID(meta)) if agentID == "" { - return prompt + return "" } if meta.SessionBootstrapByAgent == nil { meta.SessionBootstrapByAgent = make(map[string]int64) } if meta.SessionBootstrapByAgent[agentID] != 0 { - return prompt + return "" } meta.SessionBootstrapByAgent[agentID] = time.Now().UnixMilli() if portal != nil { @@ -39,6 +37,5 @@ func maybePrependSessionGreeting( log.Warn().Err(err).Msg("Failed to persist session bootstrap state") } } - greeting := openai.SystemMessage(sessionGreetingPrompt) - return append([]openai.ChatCompletionMessageParamUnion{greeting}, prompt...) + return sessionGreetingPrompt } diff --git a/bridges/ai/session_greeting_test.go b/bridges/ai/session_greeting_test.go index 4d801c00..15ecbad2 100644 --- a/bridges/ai/session_greeting_test.go +++ b/bridges/ai/session_greeting_test.go @@ -4,31 +4,23 @@ import ( "context" "testing" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" ) -func TestMaybePrependSessionGreeting(t *testing.T) { +func TestSessionGreetingFragment(t *testing.T) { ctx := context.Background() meta := agentModeTestMeta("beeper") - prompt := []openai.ChatCompletionMessageParamUnion{} - out := maybePrependSessionGreeting(ctx, nil, meta, prompt, zerolog.Nop()) - if len(out) != 1 { - t.Fatalf("expected 1 greeting message, got %d", len(out)) + greeting := sessionGreetingFragment(ctx, nil, meta, zerolog.Nop()) + if greeting != sessionGreetingPrompt { + t.Fatalf("expected greeting prompt, got %q", greeting) } if meta.SessionBootstrapByAgent == nil || meta.SessionBootstrapByAgent["beeper"] == 0 { t.Fatal("expected SessionBootstrapByAgent to be set") } - if out[0].OfSystem == nil { - t.Fatal("expected system message") - } - if out[0].OfSystem.Content.OfString.Value != sessionGreetingPrompt { - t.Fatalf("unexpected greeting content: %q", out[0].OfSystem.Content.OfString.Value) - } - out2 := maybePrependSessionGreeting(ctx, nil, meta, []openai.ChatCompletionMessageParamUnion{}, zerolog.Nop()) - if len(out2) != 0 { - t.Fatalf("expected no additional greeting, got %d", len(out2)) + greeting2 := sessionGreetingFragment(ctx, nil, meta, zerolog.Nop()) + if greeting2 != "" { + t.Fatalf("expected no additional greeting, got %q", greeting2) } } diff --git a/bridges/ai/status_text.go b/bridges/ai/status_text.go index 10f00307..4b9a747b 100644 --- a/bridges/ai/status_text.go +++ b/bridges/ai/status_text.go @@ -224,17 +224,12 @@ func (oc *AIClient) estimatePromptTokens(ctx context.Context, portal *bridgev2.P if oc == nil || portal == nil { return 0 } - prompt, err := oc.buildBasePrompt(ctx, portal, meta) + promptContext, err := oc.buildBaseContext(ctx, portal, meta) if err != nil { return 0 } - prompt = oc.augmentPromptWithIntegrations(ctx, portal, meta, prompt) modelID := oc.effectiveModel(meta) - count, err := EstimateTokens(prompt, modelID) - if err != nil { - return 0 - } - return count + return estimatePromptContextTokensForModel(promptContext, modelID) } func (oc *AIClient) getSessionEntryMaybe(ctx context.Context, agentID, sessionKey string) *sessionEntry { diff --git a/bridges/ai/streaming_chat_completions.go b/bridges/ai/streaming_chat_completions.go index d2881188..4f1c3ea6 100644 --- a/bridges/ai/streaming_chat_completions.go +++ b/bridges/ai/streaming_chat_completions.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" @@ -32,7 +31,7 @@ func (a *chatCompletionsTurnAdapter) handleStreamStepError( if cle := ParseContextLengthError(stepErr); cle != nil { return cle, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, a.state, a.meta, "context-length", stepErr) } - logChatCompletionsFailure(a.log, stepErr, params, a.meta, currentMessages, "stream_err") + logChatCompletionsFailure(a.log, stepErr, params, a.meta, a.prompt, "stream_err") return nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, a.state, a.meta, "error", stepErr) } @@ -49,14 +48,14 @@ func (a *chatCompletionsTurnAdapter) RunAgentTurn( typingSignals := a.typingSignals touchTyping := a.touchTyping isHeartbeat := a.isHeartbeat - currentMessages := a.messages + currentMessages := promptContextToChatCompletionMessages(a.prompt, oc.isOpenRouterProvider()) params := oc.buildChatCompletionsAgentLoopParams(ctx, meta, currentMessages) stream := oc.api.Chat.Completions.NewStreaming(ctx, params) if stream == nil { initErr := errors.New("chat completions streaming not available") - logChatCompletionsFailure(log, initErr, params, meta, currentMessages, "stream_init") + logChatCompletionsFailure(log, initErr, params, meta, a.prompt, "stream_init") return false, nil, oc.finishStreamingWithFailure(ctx, log, portal, state, meta, "error", initErr) } @@ -139,33 +138,60 @@ func (a *chatCompletionsTurnAdapter) RunAgentTurn( if shouldContinueChatToolLoop(state.finishReason, len(toolCallParams)) { state.needsTextSeparator = true - assistantMsg := openai.ChatCompletionAssistantMessageParam{ - ToolCalls: toolCallParams, + assistantMsg := PromptMessage{ + Role: PromptRoleAssistant, } if content := strings.TrimSpace(roundContent.String()); content != "" { - assistantMsg.Content.OfString = param.NewOpt(content) + assistantMsg.Blocks = append(assistantMsg.Blocks, PromptBlock{ + Type: PromptBlockText, + Text: content, + }) + } + for _, toolCall := range toolCallParams { + if toolCall.OfFunction == nil { + continue + } + assistantMsg.Blocks = append(assistantMsg.Blocks, PromptBlock{ + Type: PromptBlockToolCall, + ToolCallID: toolCall.OfFunction.ID, + ToolName: toolCall.OfFunction.Function.Name, + ToolCallArguments: toolCall.OfFunction.Function.Arguments, + }) + } + if len(assistantMsg.Blocks) > 0 { + a.prompt.Messages = append(a.prompt.Messages, assistantMsg) } - currentMessages = append(currentMessages, openai.ChatCompletionMessageParamUnion{OfAssistant: &assistantMsg}) for _, output := range state.pendingFunctionOutputs { - currentMessages = append(currentMessages, openai.ToolMessage(output.output, output.callID)) + a.prompt.Messages = append(a.prompt.Messages, PromptMessage{ + Role: PromptRoleToolResult, + ToolCallID: output.callID, + ToolName: output.name, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: output.output, + }}, + }) } - currentMessages = append(currentMessages, buildSteeringUserMessages(steeringPrompts)...) + a.prompt.Messages = append(a.prompt.Messages, buildSteeringPromptMessages(steeringPrompts)...) if round >= maxAgentLoopToolTurns { log.Warn().Int("rounds", round+1).Msg("Max tool call rounds reached; stopping chat completions continuation") - currentMessages = append(currentMessages, openai.AssistantMessage("Continuation stopped after reaching the maximum number of streaming tool rounds.")) + a.prompt.Messages = append(a.prompt.Messages, PromptMessage{ + Role: PromptRoleAssistant, + Blocks: []PromptBlock{{ + Type: PromptBlockText, + Text: "Continuation stopped after reaching the maximum number of streaming tool rounds.", + }}, + }) state.clearContinuationState() - a.messages = currentMessages return false, nil, nil } // Chat Completions does not support MCP approvals; clearContinuationState // is safe here — it resets pendingFunctionOutputs (consumed above) and // pendingMcpApprovals (always empty for Chat). state.clearContinuationState() - a.messages = currentMessages return true, nil, nil } - a.messages = currentMessages return false, nil, nil } @@ -189,12 +215,12 @@ func (a *chatCompletionsTurnAdapter) FinalizeAgentLoop(ctx context.Context) { } -func (oc *AIClient) runChatCompletionsAgentLoop( +func (oc *AIClient) runChatCompletionsAgentLoopPrompt( ctx context.Context, evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) (bool, *ContextLengthError, error) { portalID := "" if portal != nil { @@ -205,13 +231,9 @@ func (oc *AIClient) runChatCompletionsAgentLoop( Str("portal", portalID). Logger() - return oc.runAgentLoop(ctx, log, evt, portal, meta, messages, func(prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion) agentLoopProvider { + return oc.runAgentLoop(ctx, log, evt, portal, meta, prompt, func(prep streamingRunPrep, prompt PromptContext) agentLoopProvider { return &chatCompletionsTurnAdapter{ - agentLoopProviderBase: newAgentLoopProviderBase(oc, log, portal, meta, prep, pruned), + agentLoopProviderBase: newAgentLoopProviderBase(oc, log, portal, meta, prep, prompt), } }) } - -// convertToResponsesInput converts Chat Completion messages to Responses API input items -// Supports native multimodal content: images (ResponseInputImageParam), files/PDFs (ResponseInputFileParam) -// Note: Audio is handled via Chat Completions API fallback (SDK v3.16.0 lacks Responses API audio union support) diff --git a/bridges/ai/streaming_continuation.go b/bridges/ai/streaming_continuation.go index ea2a4ec5..0c74d1b2 100644 --- a/bridges/ai/streaming_continuation.go +++ b/bridges/ai/streaming_continuation.go @@ -4,7 +4,6 @@ import ( "context" "strings" - "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" ) @@ -12,16 +11,15 @@ import ( // and/or after responding to tool approval requests. func (oc *AIClient) buildContinuationParams( ctx context.Context, + prompt *PromptContext, state *streamingState, meta *PortalMetadata, pendingOutputs []functionCallOutput, approvalInputs []responses.ResponseInputItemUnionParam, ) responses.ResponseNewParams { - // Build function call outputs as input var input responses.ResponseInputParam - if len(state.baseInput) > 0 { - // All Responses continuations are stateless: include the accumulated local history. - input = append(input, state.baseInput...) + if prompt != nil { + input = append(input, promptContextToResponsesInput(*prompt)...) } input = append(input, approvalInputs...) for _, output := range pendingOutputs { @@ -39,13 +37,20 @@ func (oc *AIClient) buildContinuationParams( steerPrompts = oc.getSteeringMessages(state.roomID) } if len(steerPrompts) > 0 { + steeringMessages := buildSteeringPromptMessages(steerPrompts) + if prompt != nil && len(steeringMessages) > 0 { + prompt.Messages = append(prompt.Messages, steeringMessages...) + } steerInput := oc.buildSteeringInputItems(steerPrompts, meta) if len(steerInput) > 0 { input = append(input, steerInput...) - state.baseInput = append(state.baseInput, steerInput...) } } - return oc.buildResponsesAgentLoopParams(ctx, meta, input, true) + systemPrompt := "" + if prompt != nil { + systemPrompt = prompt.SystemPrompt + } + return oc.buildResponsesAgentLoopParams(ctx, meta, systemPrompt, input, true) } func (oc *AIClient) buildSteeringInputItems(prompts []string, meta *PortalMetadata) responses.ResponseInputParam { @@ -58,8 +63,9 @@ func (oc *AIClient) buildSteeringInputItems(prompts []string, meta *PortalMetada if prompt == "" { continue } - messages := []openai.ChatCompletionMessageParamUnion{openai.UserMessage(prompt)} - input = append(input, oc.convertToResponsesInput(messages, meta)...) + input = append(input, promptContextToResponsesInput(UserPromptContext( + PromptBlock{Type: PromptBlockText, Text: prompt}, + ))...) } return input } diff --git a/bridges/ai/streaming_error_handling_test.go b/bridges/ai/streaming_error_handling_test.go index b981929f..6ae8fbfa 100644 --- a/bridges/ai/streaming_error_handling_test.go +++ b/bridges/ai/streaming_error_handling_test.go @@ -14,7 +14,7 @@ import ( func newTestStreamingStateWithTurn() *streamingState { state := newStreamingState(context.Background(), nil, "") - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) state.turn = conv.StartTurn(context.Background(), nil, nil) return state } diff --git a/bridges/ai/streaming_executor.go b/bridges/ai/streaming_executor.go index 6a4811eb..303d2956 100644 --- a/bridges/ai/streaming_executor.go +++ b/bridges/ai/streaming_executor.go @@ -3,7 +3,6 @@ package ai import ( "context" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" @@ -14,8 +13,6 @@ import ( type agentLoopProvider interface { TrackRoomRunStreaming() bool RunAgentTurn(ctx context.Context, evt *event.Event, round int) (continueLoop bool, cle *ContextLengthError, err error) - GetFollowUpMessages(ctx context.Context) []openai.ChatCompletionMessageParamUnion - ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) FinalizeAgentLoop(ctx context.Context) } @@ -28,7 +25,7 @@ type agentLoopProviderBase struct { typingSignals *TypingSignaler touchTyping func() isHeartbeat bool - messages []openai.ChatCompletionMessageParamUnion + prompt PromptContext } func newAgentLoopProviderBase( @@ -37,7 +34,7 @@ func newAgentLoopProviderBase( portal *bridgev2.Portal, meta *PortalMetadata, prep streamingRunPrep, - messages []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) agentLoopProviderBase { return agentLoopProviderBase{ oc: oc, @@ -48,38 +45,24 @@ func newAgentLoopProviderBase( typingSignals: prep.TypingSignals, touchTyping: prep.TouchTyping, isHeartbeat: prep.IsHeartbeat, - messages: messages, + prompt: prompt, } } -func (a *agentLoopProviderBase) GetFollowUpMessages(context.Context) []openai.ChatCompletionMessageParamUnion { - if a == nil || a.oc == nil || a.state == nil { - return nil - } - return a.oc.getFollowUpMessages(a.state.roomID) -} - -func (a *agentLoopProviderBase) ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) { - if a == nil || len(messages) == 0 { - return - } - a.messages = append(a.messages, messages...) -} - func (oc *AIClient) runAgentLoop( ctx context.Context, log zerolog.Logger, evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, - newProvider func(prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion) agentLoopProvider, + prompt PromptContext, + newProvider func(prep streamingRunPrep, prompt PromptContext) agentLoopProvider, ) (bool, *ContextLengthError, error) { - prep, pruned, typingCleanup := oc.prepareStreamingRun(ctx, log, evt, portal, meta, messages) + prep, typingCleanup := oc.prepareStreamingRun(ctx, log, evt, portal, meta) defer typingCleanup() state := prep.State - provider := newProvider(prep, pruned) + provider := newProvider(prep, prompt) if state.roomID != "" { if provider.TrackRoomRunStreaming() { oc.markRoomRunStreaming(state.roomID, true) @@ -106,12 +89,8 @@ func executeAgentLoopRounds( continue } - followUpMessages := provider.GetFollowUpMessages(ctx) - if len(followUpMessages) > 0 { - provider.ContinueAgentLoop(followUpMessages) - continue - } - + // Queued user messages are dispatched after room release via processPendingQueue. + // Finalize this turn immediately so later prompts cannot reopen it with more edits. finalizeAgentLoopExit(ctx, provider, false) return true, nil, nil } diff --git a/bridges/ai/streaming_init.go b/bridges/ai/streaming_init.go index dc66cf84..1833ec07 100644 --- a/bridges/ai/streaming_init.go +++ b/bridges/ai/streaming_init.go @@ -3,7 +3,6 @@ package ai import ( "context" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" @@ -24,7 +23,7 @@ func (oc *AIClient) createStreamingTurn( sourceEventID id.EventID, senderID string, ) *bridgesdk.Turn { - var sdkConfig *bridgesdk.Config + var sdkConfig *bridgesdk.Config[*AIClient, *Config] if oc.connector != nil { sdkConfig = oc.connector.sdkConfig } @@ -103,8 +102,7 @@ func (oc *AIClient) prepareStreamingRun( evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, -) (prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion, cleanup func()) { +) (prep streamingRunPrep, cleanup func()) { var sourceEventID id.EventID senderID := "" if evt != nil { @@ -176,13 +174,11 @@ func (oc *AIClient) prepareStreamingRun( } } - pruned = messages - prep = streamingRunPrep{ State: state, TypingSignals: typingSignals, TouchTyping: touchTyping, IsHeartbeat: isHeartbeat, } - return prep, pruned, cleanup + return prep, cleanup } diff --git a/bridges/ai/streaming_init_test.go b/bridges/ai/streaming_init_test.go index dac03c14..34cb5d67 100644 --- a/bridges/ai/streaming_init_test.go +++ b/bridges/ai/streaming_init_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - "github.com/openai/openai-go/v3" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" @@ -29,13 +28,12 @@ func TestPrepareStreamingRun_ModelRoomKeepsReplyTarget(t *testing.T) { }, } - prep, _, cleanup := oc.prepareStreamingRun( + prep, cleanup := oc.prepareStreamingRun( context.Background(), zerolog.Nop(), evt, nil, meta, - []openai.ChatCompletionMessageParamUnion{}, ) defer cleanup() @@ -64,13 +62,12 @@ func TestPrepareStreamingRun_AgentRoomKeepsReplyTarget(t *testing.T) { }, } - prep, _, cleanup := oc.prepareStreamingRun( + prep, cleanup := oc.prepareStreamingRun( context.Background(), zerolog.Nop(), evt, nil, meta, - []openai.ChatCompletionMessageParamUnion{}, ) defer cleanup() @@ -94,13 +91,12 @@ func TestPrepareStreamingRun_SnapshotsResponderFields(t *testing.T) { } meta := modelModeTestMeta("openai/gpt-5.2") - prep, _, cleanup := oc.prepareStreamingRun( + prep, cleanup := oc.prepareStreamingRun( context.Background(), zerolog.Nop(), nil, nil, meta, - []openai.ChatCompletionMessageParamUnion{}, ) defer cleanup() diff --git a/bridges/ai/streaming_input_conversion.go b/bridges/ai/streaming_input_conversion.go index b360733a..6c021fc8 100644 --- a/bridges/ai/streaming_input_conversion.go +++ b/bridges/ai/streaming_input_conversion.go @@ -1,38 +1,10 @@ package ai -import ( - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/responses" - - bridgesdk "github.com/beeper/agentremote/sdk" -) - -func (oc *AIClient) convertToResponsesInput(messages []openai.ChatCompletionMessageParamUnion, _ *PortalMetadata) responses.ResponseInputParam { - return bridgesdk.PromptContextToResponsesInput(bridgesdk.ChatMessagesToPromptContext(messages)) -} - -// hasAudioContent checks if the prompt contains audio content -func hasAudioContent(messages []openai.ChatCompletionMessageParamUnion) bool { - for _, msg := range messages { - if msg.OfUser != nil && len(msg.OfUser.Content.OfArrayOfContentParts) > 0 { - for _, part := range msg.OfUser.Content.OfArrayOfContentParts { - if part.OfInputAudio != nil { - return true - } - } - } - } - return false -} - -// hasMultimodalContent checks if the prompt contains non-text content (image, file, audio). -func hasMultimodalContent(messages []openai.ChatCompletionMessageParamUnion) bool { - for _, msg := range messages { - if msg.OfUser != nil && len(msg.OfUser.Content.OfArrayOfContentParts) > 0 { - for _, part := range msg.OfUser.Content.OfArrayOfContentParts { - if part.OfImageURL != nil || part.OfFile != nil || part.OfInputAudio != nil { - return true - } +func promptHasMultimodalContent(prompt PromptContext) bool { + for _, msg := range prompt.Messages { + for _, block := range msg.Blocks { + if block.Type == PromptBlockImage { + return true } } } diff --git a/bridges/ai/streaming_output_items_test.go b/bridges/ai/streaming_output_items_test.go index 0ebac79b..80d6c293 100644 --- a/bridges/ai/streaming_output_items_test.go +++ b/bridges/ai/streaming_output_items_test.go @@ -60,7 +60,7 @@ func TestDeriveToolDescriptorForOutputItem_FunctionCallParsesArgumentsJSON(t *te func TestUpsertActiveToolFromDescriptor_RecreatesNilMapEntry(t *testing.T) { oc := &AIClient{} state := newStreamingState(context.Background(), nil, "") - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) state.turn = conv.StartTurn(context.Background(), nil, nil) activeTools := newStreamToolRegistry() activeTools.byKey[streamToolItemKey("item_123")] = nil diff --git a/bridges/ai/streaming_request_tools_test.go b/bridges/ai/streaming_request_tools_test.go index 02b7eba3..2fab20fe 100644 --- a/bridges/ai/streaming_request_tools_test.go +++ b/bridges/ai/streaming_request_tools_test.go @@ -2,48 +2,32 @@ package ai import ( "context" + "slices" "testing" - "time" - - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" ) -func testToolSelectionClient(supportsToolCalling bool) *AIClient { - return &AIClient{ - connector: &OpenAIConnector{ - Config: Config{ - Tools: ToolProvidersConfig{ - Search: &SearchConfig{ - Exa: ProviderExaConfig{APIKey: "test"}, - }, - }, - }, - }, - UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ - ModelCache: &ModelCache{ - Models: []ModelInfo{{ - ID: "openai/gpt-5.2", - SupportsToolCalling: supportsToolCalling, - }}, - LastRefresh: time.Now().Unix(), - CacheDuration: 3600, - }, - }}}, - } -} - func TestSelectedStreamingToolDescriptorsSkipsAllToolsWhenModelCannotCallTools(t *testing.T) { - meta := agentModeTestMeta("beeper") - meta.RuntimeModelOverride = "openai/gpt-5.2" + meta := modelModeTestMeta("openai/gpt-5.2") - withTools := testToolSelectionClient(true).selectedStreamingToolDescriptors(context.Background(), meta, false) + withTools := testBuiltinToolClient(true, true, true).selectedStreamingToolDescriptors(context.Background(), meta, false) if len(withTools) == 0 { t.Fatal("expected tool descriptors when tool calling is supported") } - withoutTools := testToolSelectionClient(false).selectedStreamingToolDescriptors(context.Background(), meta, false) + withoutTools := testBuiltinToolClient(false, true, true).selectedStreamingToolDescriptors(context.Background(), meta, false) if len(withoutTools) != 0 { t.Fatalf("expected no tool descriptors when tool calling is unsupported, got %#v", withoutTools) } } + +func TestBuildResponsesAgentLoopParams_ModelRoomUsesModelPreset(t *testing.T) { + meta := modelModeTestMeta("openai/gpt-5.2") + client := testBuiltinToolClient(true, true, true) + + params := client.buildResponsesAgentLoopParams(context.Background(), meta, "system prompt", nil, false) + got := responsesToolNames(params.Tools) + want := []string{toolNameSessionStatus, toolNameWebFetch, ToolNameWebSearch} + if !slices.Equal(got, want) { + t.Fatalf("unexpected response tool list: got %v want %v", got, want) + } +} diff --git a/bridges/ai/streaming_responses_api.go b/bridges/ai/streaming_responses_api.go index 53921829..861907e3 100644 --- a/bridges/ai/streaming_responses_api.go +++ b/bridges/ai/streaming_responses_api.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/packages/param" "github.com/openai/openai-go/v3/packages/ssestream" "github.com/openai/openai-go/v3/responses" @@ -28,7 +27,6 @@ type responsesTurnAdapter struct { agentLoopProviderBase params responses.ResponseNewParams initialized bool - hasFollowUp bool rsc *responseStreamContext } @@ -38,8 +36,8 @@ func (a *responsesTurnAdapter) TrackRoomRunStreaming() bool { func (a *responsesTurnAdapter) startInitialRound(ctx context.Context) (*ssestream.Stream[responses.ResponseStreamEventUnion], error) { if !a.initialized { - input := a.oc.convertToResponsesInput(a.messages, a.meta) - a.params = a.oc.buildResponsesAgentLoopParams(ctx, a.meta, input, false) + input := promptContextToResponsesInput(a.prompt) + a.params = a.oc.buildResponsesAgentLoopParams(ctx, a.meta, a.prompt.SystemPrompt, input, false) if len(a.params.Tools) > 0 { zerolog.Ctx(ctx).Debug().Int("count", len(a.params.Tools)).Msg("Added streaming turn tools") } @@ -52,9 +50,6 @@ func (a *responsesTurnAdapter) startInitialRound(ctx context.Context) (*ssestrea if stream == nil { return nil, errors.New("responses streaming not available") } - if a.params.Input.OfInputItemList != nil { - a.state.baseInput = a.params.Input.OfInputItemList - } return stream, nil } @@ -89,17 +84,13 @@ func (a *responsesTurnAdapter) startContinuationRound(ctx context.Context) (*sse approvalInputs = append(approvalInputs, item) } - continuationParams := a.oc.buildContinuationParams(ctx, state, a.meta, pendingOutputs, approvalInputs) - if continuationInput := continuationParams.Input.OfInputItemList; continuationInput != nil { - state.baseInput = slices.Clone(continuationInput) - } + continuationParams := a.oc.buildContinuationParams(ctx, &a.prompt, state, a.meta, pendingOutputs, approvalInputs) state.needsTextSeparator = true stream := a.oc.api.Responses.NewStreaming(ctx, continuationParams) if stream == nil { return nil, continuationParams, errors.New("continuation streaming not available") } - a.hasFollowUp = false state.clearContinuationState() return stream, continuationParams, nil } @@ -120,11 +111,11 @@ func (a *responsesTurnAdapter) RunAgentTurn( stream, err = a.startInitialRound(ctx) params = a.params if err != nil { - logResponsesFailure(a.log, err, params, a.meta, a.messages, "stream_init") + logResponsesFailure(a.log, err, params, a.meta, a.prompt, "stream_init") return false, nil, &PreDeltaError{Err: err} } } else { - if len(state.pendingFunctionOutputs) == 0 && len(state.pendingMcpApprovals) == 0 && !a.hasFollowUp { + if len(state.pendingFunctionOutputs) == 0 && len(state.pendingMcpApprovals) == 0 && len(state.pendingSteeringPrompts) == 0 { return false, nil, nil } if round > maxAgentLoopToolTurns { @@ -135,14 +126,14 @@ func (a *responsesTurnAdapter) RunAgentTurn( a.log.Debug(). Int("pending_outputs", len(state.pendingFunctionOutputs)). Int("pending_approvals", len(state.pendingMcpApprovals)). - Int("base_input_items", len(state.baseInput)). + Int("prompt_messages", len(a.prompt.Messages)). Msg("Continuing stateless response with pending tool actions") stream, params, err = a.startContinuationRound(ctx) if err != nil { if errors.Is(err, context.Canceled) { return false, nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, state, a.meta, "cancelled", err) } - logResponsesFailure(a.log, err, params, a.meta, a.messages, "continuation_init") + logResponsesFailure(a.log, err, params, a.meta, a.prompt, "continuation_init") return false, nil, a.oc.finishStreamingWithFailure(ctx, a.log, a.portal, state, a.meta, "error", err) } } @@ -158,7 +149,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( if round > 0 { stage = "continuation_event_error" } - logResponsesFailure(a.log, evtErr, params, a.meta, a.messages, stage) + logResponsesFailure(a.log, evtErr, params, a.meta, a.prompt, stage) } return done, cle, evtErr }, @@ -167,7 +158,7 @@ func (a *responsesTurnAdapter) RunAgentTurn( if round > 0 { stage = "continuation_err" } - logResponsesFailure(a.log, stepErr, params, a.meta, a.messages, stage) + logResponsesFailure(a.log, stepErr, params, a.meta, a.prompt, stage) return a.oc.handleResponsesStreamErr(ctx, a.portal, state, a.meta, stepErr, round == 0) }, ) @@ -175,25 +166,16 @@ func (a *responsesTurnAdapter) RunAgentTurn( return false, cle, err } if done { - return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0), nil, nil + return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0 || len(state.pendingSteeringPrompts) > 0), nil, nil } - return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0), nil, nil + return state != nil && (len(state.pendingFunctionOutputs) > 0 || len(state.pendingMcpApprovals) > 0 || len(state.pendingSteeringPrompts) > 0), nil, nil } func (a *responsesTurnAdapter) FinalizeAgentLoop(ctx context.Context) { a.oc.finalizeResponsesStream(ctx, a.log, a.portal, a.state, a.meta) } -func (a *responsesTurnAdapter) ContinueAgentLoop(messages []openai.ChatCompletionMessageParamUnion) { - if len(messages) == 0 { - return - } - a.messages = append(a.messages, messages...) - a.state.baseInput = append(a.state.baseInput, a.oc.convertToResponsesInput(messages, a.meta)...) - a.hasFollowUp = true -} - // processResponseStreamEvent handles a single Responses API stream event. // Returns done=true when the caller's loop should break (error/fatal), along with // any context-length error or general error. The caller is responsible for @@ -464,12 +446,12 @@ func (oc *AIClient) handleProviderToolCompleted( } // runResponsesAgentLoop handles the Responses API provider adapter under the canonical agent loop. -func (oc *AIClient) runResponsesAgentLoop( +func (oc *AIClient) runResponsesAgentLoopPrompt( ctx context.Context, evt *event.Event, portal *bridgev2.Portal, meta *PortalMetadata, - messages []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) (bool, *ContextLengthError, error) { portalID := "" if portal != nil { @@ -478,8 +460,8 @@ func (oc *AIClient) runResponsesAgentLoop( log := zerolog.Ctx(ctx).With(). Str("portal_id", portalID). Logger() - return oc.runAgentLoop(ctx, log, evt, portal, meta, messages, func(prep streamingRunPrep, pruned []openai.ChatCompletionMessageParamUnion) agentLoopProvider { - base := newAgentLoopProviderBase(oc, log, portal, meta, prep, pruned) + return oc.runAgentLoop(ctx, log, evt, portal, meta, prompt, func(prep streamingRunPrep, prompt PromptContext) agentLoopProvider { + base := newAgentLoopProviderBase(oc, log, portal, meta, prep, prompt) return &responsesTurnAdapter{ agentLoopProviderBase: base, rsc: &responseStreamContext{ diff --git a/bridges/ai/streaming_responses_input_test.go b/bridges/ai/streaming_responses_input_test.go index de9bcbb2..7a630800 100644 --- a/bridges/ai/streaming_responses_input_test.go +++ b/bridges/ai/streaming_responses_input_test.go @@ -9,44 +9,32 @@ import ( ) func TestConvertToResponsesInput_RolesAndToolOutput(t *testing.T) { - oc := &AIClient{} - messages := []openai.ChatCompletionMessageParamUnion{ - openai.DeveloperMessage("dev instructions"), openai.UserMessage("hello"), openai.ToolMessage("tool output", "call_123"), } - input := oc.convertToResponsesInput(messages, nil) - if len(input) != 3 { - t.Fatalf("expected 3 input items, got %d", len(input)) - } - - if input[0].OfMessage == nil { - t.Fatalf("expected developer message input, got nil") - } - if input[0].OfMessage.Role != responses.EasyInputMessageRoleDeveloper { - t.Fatalf("expected developer role, got %s", input[0].OfMessage.Role) + input := promptContextToResponsesInput(chatMessagesToPromptContext(messages)) + if len(input) != 2 { + t.Fatalf("expected 2 input items, got %d", len(input)) } - if input[1].OfMessage == nil || input[1].OfMessage.Role != responses.EasyInputMessageRoleUser { - t.Fatalf("expected user message input for item 2") + if input[0].OfMessage == nil || input[0].OfMessage.Role != responses.EasyInputMessageRoleUser { + t.Fatalf("expected user message input first") } - if input[2].OfFunctionCallOutput == nil { - t.Fatalf("expected function_call_output input for item 3") + if input[1].OfFunctionCallOutput == nil { + t.Fatalf("expected function_call_output input second") } - if input[2].OfFunctionCallOutput.CallID != "call_123" { - t.Fatalf("expected call_id call_123, got %s", input[2].OfFunctionCallOutput.CallID) + if input[1].OfFunctionCallOutput.CallID != "call_123" { + t.Fatalf("expected call_id call_123, got %s", input[1].OfFunctionCallOutput.CallID) } - if input[2].OfFunctionCallOutput.Output.OfString.Value != "tool output" { - t.Fatalf("expected tool output to match, got %q", input[2].OfFunctionCallOutput.Output.OfString.Value) + if input[1].OfFunctionCallOutput.Output.OfString.Value != "tool output" { + t.Fatalf("expected tool output to match, got %q", input[1].OfFunctionCallOutput.Output.OfString.Value) } } func TestConvertToResponsesInput_AssistantToolCalls(t *testing.T) { - oc := &AIClient{} - messages := []openai.ChatCompletionMessageParamUnion{{ OfAssistant: &openai.ChatCompletionAssistantMessageParam{ Content: openai.ChatCompletionAssistantMessageParamContentUnion{ @@ -65,7 +53,7 @@ func TestConvertToResponsesInput_AssistantToolCalls(t *testing.T) { }, }} - input := oc.convertToResponsesInput(messages, nil) + input := promptContextToResponsesInput(chatMessagesToPromptContext(messages)) if len(input) != 2 { t.Fatalf("expected 2 input items, got %d", len(input)) } diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index 460e4f00..18990d95 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -38,7 +38,6 @@ type streamingState struct { reasoningTokens int64 totalTokens int64 - baseInput responses.ResponseInputParam accumulated strings.Builder reasoning strings.Builder toolCalls []ToolCallMetadata diff --git a/bridges/ai/streaming_tool_selection.go b/bridges/ai/streaming_tool_selection.go index 1b5e70d0..4b46a378 100644 --- a/bridges/ai/streaming_tool_selection.go +++ b/bridges/ai/streaming_tool_selection.go @@ -2,15 +2,68 @@ package ai import "context" -// selectedBuiltinToolsForTurn returns builtin tools exposed to the model for a turn. -func (oc *AIClient) selectedBuiltinToolsForTurn(ctx context.Context, meta *PortalMetadata) []ToolDefinition { - if meta == nil || !oc.getModelCapabilitiesForMeta(ctx, meta).SupportsToolCalling { +type builtinToolPreset string + +const ( + builtinToolPresetNone builtinToolPreset = "" + builtinToolPresetModel builtinToolPreset = "model" + builtinToolPresetAgent builtinToolPreset = "agent" +) + +var modelChatBuiltinToolNames = []string{ + ToolNameWebSearch, + toolNameWebFetch, + toolNameSessionStatus, +} + +func selectToolDefinitionsByName(available []ToolDefinition, names []string) []ToolDefinition { + if len(available) == 0 || len(names) == 0 { return nil } + availableByName := make(map[string]ToolDefinition, len(available)) + for _, tool := range available { + if tool.Name == "" { + continue + } + availableByName[tool.Name] = tool + } + + selected := make([]ToolDefinition, 0, len(names)) + for _, name := range names { + tool, ok := availableByName[name] + if !ok { + continue + } + selected = append(selected, tool) + } + return selected +} + +func (oc *AIClient) builtinToolPresetForTurn(ctx context.Context, meta *PortalMetadata) builtinToolPreset { + if meta == nil || !oc.getModelCapabilitiesForMeta(ctx, meta).SupportsToolCalling { + return builtinToolPresetNone + } if resolveAgentID(meta) == "" { + return builtinToolPresetModel + } + return builtinToolPresetAgent +} + +// selectedBuiltinToolsForTurn returns builtin tools exposed to the model for a turn. +func (oc *AIClient) selectedBuiltinToolsForTurn(ctx context.Context, meta *PortalMetadata) []ToolDefinition { + preset := oc.builtinToolPresetForTurn(ctx, meta) + if preset == builtinToolPresetNone { return nil } - return oc.enabledBuiltinToolsForModel(ctx, meta) + enabledTools := oc.enabledBuiltinToolsForModel(ctx, meta) + switch preset { + case builtinToolPresetModel: + return selectToolDefinitionsByName(enabledTools, modelChatBuiltinToolNames) + case builtinToolPresetAgent: + return enabledTools + default: + return nil + } } diff --git a/bridges/ai/streaming_tool_selection_test.go b/bridges/ai/streaming_tool_selection_test.go index 054a6345..b9d6322b 100644 --- a/bridges/ai/streaming_tool_selection_test.go +++ b/bridges/ai/streaming_tool_selection_test.go @@ -2,6 +2,7 @@ package ai import ( "context" + "slices" "testing" "time" @@ -9,13 +10,29 @@ import ( "maunium.net/go/mautrix/bridgev2/database" ) -func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) { - client := &AIClient{ +func testBuiltinToolClient(supportsToolCalling, searchConfigured, fetchConfigured bool) *AIClient { + searchCfg := &SearchConfig{ + Exa: ProviderExaConfig{Enabled: boolPtr(false)}, + } + if searchConfigured { + searchCfg.Exa = ProviderExaConfig{ + Enabled: boolPtr(true), + APIKey: "test-key", + } + } + + fetchCfg := &FetchConfig{ + Exa: ProviderExaConfig{Enabled: boolPtr(false)}, + Direct: ProviderDirectConfig{Enabled: boolPtr(fetchConfigured)}, + } + + return &AIClient{ connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{ - Exa: ProviderExaConfig{APIKey: "test-key"}, + Web: &WebToolsConfig{ + Search: searchCfg, + Fetch: fetchCfg, }, }, }, @@ -24,13 +41,29 @@ func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) ModelCache: &ModelCache{ Models: []ModelInfo{{ ID: "openai/gpt-5.2", - SupportsToolCalling: true, + SupportsToolCalling: supportsToolCalling, }}, LastRefresh: time.Now().Unix(), CacheDuration: 3600, }, }}}, } +} + +func toolDefinitionNames(tools []ToolDefinition) []string { + names := make([]string, 0, len(tools)) + for _, tool := range tools { + if tool.Name == "" { + continue + } + names = append(names, tool.Name) + } + slices.Sort(names) + return names +} + +func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) { + client := testBuiltinToolClient(true, true, true) meta := agentModeTestMeta("beeper") meta.RuntimeModelOverride = "openai/gpt-5.2" @@ -41,23 +74,45 @@ func TestSelectedBuiltinToolsForTurn_AgentRoomExposesBuiltinTools(t *testing.T) } } -func TestSelectedBuiltinToolsForTurn_ModelRoomGetsNoTools(t *testing.T) { - client := &AIClient{ - connector: &OpenAIConnector{ - Config: Config{ - Tools: ToolProvidersConfig{ - Search: &SearchConfig{ - Exa: ProviderExaConfig{APIKey: "test-key"}, - }, - }, - }, - }, +func TestSelectedBuiltinToolsForTurn_ModelRoomExposesModelPreset(t *testing.T) { + client := testBuiltinToolClient(true, true, true) + meta := modelModeTestMeta("openai/gpt-5.2") + + got := toolDefinitionNames(client.selectedBuiltinToolsForTurn(context.Background(), meta)) + want := []string{toolNameSessionStatus, toolNameWebFetch, ToolNameWebSearch} + if !slices.Equal(got, want) { + t.Fatalf("unexpected model room builtin tools: got %v want %v", got, want) + } +} + +func TestSelectedBuiltinToolsForTurn_ModelRoomOmitsUnavailableWebTools(t *testing.T) { + client := testBuiltinToolClient(true, false, false) + meta := modelModeTestMeta("openai/gpt-5.2") + + got := toolDefinitionNames(client.selectedBuiltinToolsForTurn(context.Background(), meta)) + want := []string{toolNameSessionStatus} + if !slices.Equal(got, want) { + t.Fatalf("unexpected model room builtin tools with web tools disabled: got %v want %v", got, want) } +} + +func TestSelectedBuiltinToolsForTurn_ModelRoomOmitsOnlyUnavailableSearch(t *testing.T) { + client := testBuiltinToolClient(true, false, true) + meta := modelModeTestMeta("openai/gpt-5.2") + + got := toolDefinitionNames(client.selectedBuiltinToolsForTurn(context.Background(), meta)) + want := []string{toolNameSessionStatus, toolNameWebFetch} + if !slices.Equal(got, want) { + t.Fatalf("unexpected model room builtin tools with search disabled: got %v want %v", got, want) + } +} - meta := &PortalMetadata{} +func TestSelectedBuiltinToolsForTurn_ModelRoomWithoutToolCallingGetsNoTools(t *testing.T) { + client := testBuiltinToolClient(false, true, true) + meta := modelModeTestMeta("openai/gpt-5.2") got := client.selectedBuiltinToolsForTurn(context.Background(), meta) if len(got) != 0 { - t.Fatalf("expected no builtin tools when room has no assigned agent, got %d", len(got)) + t.Fatalf("expected no builtin tools when model does not support tool calling, got %d", len(got)) } } diff --git a/bridges/ai/streaming_ui_tools_test.go b/bridges/ai/streaming_ui_tools_test.go index 6d910464..0c59617f 100644 --- a/bridges/ai/streaming_ui_tools_test.go +++ b/bridges/ai/streaming_ui_tools_test.go @@ -46,7 +46,7 @@ func TestRequestTurnApprovalWithoutApprovalFlowReturnsHandle(t *testing.T) { func TestStartStreamingMCPApprovalAutoApprovedEmitsApprovalRequest(t *testing.T) { oc := newTestAIClient("@owner:example.com") state := newStreamingState(context.Background(), nil, "") - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) state.turn = conv.StartTurn(context.Background(), nil, nil) handle, err := oc.startStreamingMCPApproval(context.Background(), nil, state, ToolApprovalParams{ diff --git a/bridges/ai/subagent_announce.go b/bridges/ai/subagent_announce.go index ffc92219..c09f4d91 100644 --- a/bridges/ai/subagent_announce.go +++ b/bridges/ai/subagent_announce.go @@ -7,12 +7,9 @@ import ( "strings" "time" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/id" - - bridgesdk "github.com/beeper/agentremote/sdk" ) func formatDurationShort(valueMs int64) string { @@ -144,9 +141,9 @@ func (oc *AIClient) runSubagentCompletion( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) (bool, error) { - responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, PromptContext{PromptContext: bridgesdk.ChatMessagesToPromptContext(prompt)}) + responseFn, logLabel := oc.selectAgentLoopRunFunc(meta, prompt) return oc.responseWithRetry(ctx, nil, portal, meta, prompt, responseFn, logLabel) } @@ -155,7 +152,7 @@ func (oc *AIClient) runSubagentAndAnnounce( run *subagentRun, childPortal *bridgev2.Portal, childMeta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) { if oc == nil || run == nil || childPortal == nil || childMeta == nil { return diff --git a/bridges/ai/subagent_spawn.go b/bridges/ai/subagent_spawn.go index 1af0f146..3fa98c0c 100644 --- a/bridges/ai/subagent_spawn.go +++ b/bridges/ai/subagent_spawn.go @@ -9,7 +9,6 @@ import ( "time" "github.com/google/uuid" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" @@ -331,15 +330,13 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P } eventID := agentremote.NewEventID("subagent") - promptContext, err := oc.buildContextWithLinkContext(ctx, childPortal, childMeta, task, nil, eventID) + promptContext, err := oc.buildCurrentTurnWithLinks(ctx, childPortal, childMeta, task, nil, eventID) if err != nil { return tools.JSONResult(map[string]any{ "status": "error", "error": err.Error(), }), nil } - promptMessages := oc.promptContextToDispatchMessages(ctx, childPortal, childMeta, promptContext) - userMessage := &database.Message{ ID: agentremote.MatrixMessageID(eventID), MXID: eventID, @@ -351,7 +348,6 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P Timestamp: time.Now(), } setCanonicalTurnDataFromPromptMessages(userMessage.Metadata.(*MessageMetadata), promptTail(promptContext, 1)) - ensureCanonicalUserMessage(userMessage) if _, err := oc.UserLogin.Bridge.GetGhostByID(ctx, userMessage.SenderID); err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure user ghost before saving subagent task message") } @@ -371,7 +367,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P Timeout: runTimeout, } oc.registerSubagentRun(run) - oc.startSubagentRun(ctx, run, childPortal, childMeta, promptMessages) + oc.startSubagentRun(ctx, run, childPortal, childMeta, promptContext) payload := map[string]any{ "status": "accepted", @@ -393,7 +389,7 @@ func (oc *AIClient) startSubagentRun( run *subagentRun, childPortal *bridgev2.Portal, childMeta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, + prompt PromptContext, ) { if run == nil || childPortal == nil || childMeta == nil { return diff --git a/bridges/ai/system_prompts.go b/bridges/ai/system_prompts.go index 17726a3c..54c63ce3 100644 --- a/bridges/ai/system_prompts.go +++ b/bridges/ai/system_prompts.go @@ -4,9 +4,9 @@ import ( "context" "strings" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" + integrationruntime "github.com/beeper/agentremote/pkg/integrations/runtime" runtimeparse "github.com/beeper/agentremote/pkg/runtime" ) @@ -33,10 +33,6 @@ func buildGroupIntro(roomName string, activation string) string { return strings.Join(lines, " ") + " Address the specific sender noted in the message context." } -func buildVerboseSystemHint(_ *PortalMetadata) string { - return "" -} - func buildSessionIdentityHint(portal *bridgev2.Portal, _ *PortalMetadata) string { if portal == nil { return "" @@ -55,59 +51,83 @@ func buildSessionIdentityHint(portal *bridgev2.Portal, _ *PortalMetadata) string return "sessionKey: " + session } -func (oc *AIClient) buildAdditionalSystemPrompts( +func (oc *AIClient) buildAdditionalSystemPromptText( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - return oc.additionalSystemMessages(ctx, portal, meta) +) string { + return joinPromptFragments( + oc.buildAdditionalSystemPromptCoreText(ctx, portal, meta), + oc.buildMemoryPromptContextText(ctx, portal, meta), + ) } -func (oc *AIClient) buildSystemMessages( +func (oc *AIClient) buildSystemPromptText( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - var msgs []openai.ChatCompletionMessageParamUnion - systemPrompt := oc.effectiveAgentPrompt(ctx, portal, meta) - if systemPrompt == "" { - systemPrompt = oc.effectivePrompt(meta) +) string { + base := oc.effectiveAgentPrompt(ctx, portal, meta) + if base == "" { + base = oc.effectivePrompt(meta) } - if systemPrompt != "" { - msgs = append(msgs, openai.SystemMessage(systemPrompt)) + return joinPromptFragments(base, oc.buildAdditionalSystemPromptText(ctx, portal, meta)) +} + +func (oc *AIClient) buildConversationSystemPromptText( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, + includeGreeting bool, +) string { + base := oc.buildSystemPromptText(ctx, portal, meta) + if !includeGreeting { + return base } - msgs = append(msgs, oc.buildAdditionalSystemPrompts(ctx, portal, meta)...) - return msgs + return joinPromptFragments(sessionGreetingFragment(ctx, portal, meta, oc.log), base) } -func (oc *AIClient) buildAdditionalSystemPromptsCore( +func (oc *AIClient) buildAdditionalSystemPromptCoreText( ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata, -) []openai.ChatCompletionMessageParamUnion { - var out []openai.ChatCompletionMessageParamUnion +) string { + var out []string if meta != nil && portal != nil && oc.isGroupChat(ctx, portal) { activation := oc.resolveGroupActivation(meta) intro := buildGroupIntro(oc.matrixRoomDisplayName(ctx, portal), activation) if strings.TrimSpace(intro) != "" { - out = append(out, openai.SystemMessage(intro)) - } - } - - if meta != nil { - if verboseHint := buildVerboseSystemHint(meta); verboseHint != "" { - out = append(out, openai.SystemMessage(verboseHint)) + out = append(out, intro) } } if accountHint := oc.buildDesktopAccountHintPrompt(ctx); accountHint != "" { - out = append(out, openai.SystemMessage(accountHint)) + out = append(out, accountHint) } if ident := buildSessionIdentityHint(portal, meta); ident != "" { - out = append(out, openai.SystemMessage(ident)) + out = append(out, ident) } - return out + return joinPromptFragments(out...) +} + +func (oc *AIClient) buildMemoryPromptContextText( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, +) string { + if oc == nil || len(oc.integrationModules) == 0 { + return "" + } + module := oc.integrationModules["memory"] + augmentor, ok := module.(integrationruntime.PromptContextIntegration) + if !ok || augmentor == nil { + return "" + } + return strings.TrimSpace(augmentor.PromptContextText(ctx, integrationruntime.PromptScope{ + Portal: portal, + Meta: meta, + })) } diff --git a/bridges/ai/text_files.go b/bridges/ai/text_files.go index d9447ecc..43515f1d 100644 --- a/bridges/ai/text_files.go +++ b/bridges/ai/text_files.go @@ -5,6 +5,9 @@ import ( "encoding/binary" "errors" "fmt" + "os" + "os/exec" + "path/filepath" "regexp" "strings" "unicode" @@ -22,6 +25,7 @@ var ( ) const maxTextFileBytes = 5 * 1024 * 1024 +const maxPDFFileBytes = 50 * 1024 * 1024 var textFileMimeTypesMap = map[string]event.CapabilitySupportLevel{ "text/plain": event.CapLevelFullySupported, @@ -192,7 +196,47 @@ func (oc *AIClient) downloadTextFile(ctx context.Context, mediaURL string, encry return trimmed, truncated, nil } -func buildTextFileMessage(caption string, hasUserCaption bool, filename string, mimeType string, content string, _ bool) string { +func (oc *AIClient) downloadPDFFile(ctx context.Context, mediaURL string, encryptedFile *event.EncryptedFileInfo, mimeType string) (string, bool, error) { + data, _, err := oc.downloadMediaBytes(ctx, mediaURL, encryptedFile, maxPDFFileBytes, mimeType) + if err != nil { + return "", false, err + } + + tempDir, err := os.MkdirTemp("", "ai-bridge-pdf-*") + if err != nil { + return "", false, fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + inputPath := filepath.Join(tempDir, "input.pdf") + if err := os.WriteFile(inputPath, data, 0o600); err != nil { + return "", false, fmt.Errorf("write temp pdf: %w", err) + } + + cmd := exec.CommandContext(ctx, "pdftotext", "-layout", "-enc", "UTF-8", inputPath, "-") + output, err := cmd.Output() + if err != nil { + if errors.Is(err, exec.ErrNotFound) { + return "", false, fmt.Errorf("pdftotext not found: install poppler-utils (or poppler) and ensure pdftotext is on PATH: %w", err) + } + if exitErr, ok := err.(*exec.ExitError); ok { + msg := strings.TrimSpace(string(exitErr.Stderr)) + if msg != "" { + return "", false, fmt.Errorf("pdftotext failed: %s", msg) + } + } + return "", false, fmt.Errorf("pdftotext failed: %w", err) + } + + text := strings.TrimSpace(string(output)) + if text == "" { + return "[No extractable text]", false, nil + } + trimmed, truncated := trimTextForModel(text) + return trimmed, truncated, nil +} + +func buildTextFileMessage(caption string, hasUserCaption bool, filename string, mimeType string, content string, truncated bool) string { if !hasUserCaption { caption = "" } @@ -206,10 +250,15 @@ func buildTextFileMessage(caption string, hasUserCaption bool, filename string, mimeType = "text/plain" } + truncAttr := "" + if truncated { + truncAttr = " truncated=\"true\"" + } block := fmt.Sprintf( - "\n%s\n", + "\n%s\n", xmlEscapeAttr(filename), xmlEscapeAttr(mimeType), + truncAttr, escapeFileBlockContent(content), ) diff --git a/bridges/ai/token_resolver.go b/bridges/ai/token_resolver.go index 9bd783f7..7addbea0 100644 --- a/bridges/ai/token_resolver.go +++ b/bridges/ai/token_resolver.go @@ -26,6 +26,13 @@ type ServiceConfig struct { type ServiceConfigMap map[string]ServiceConfig +func (oc *OpenAIConnector) modelProviderConfig(provider string) ModelProviderConfig { + if oc == nil || oc.Config.Models == nil { + return ModelProviderConfig{} + } + return oc.Config.Models.Provider(provider) +} + func trimToken(value string) string { return strings.TrimSpace(value) } @@ -101,10 +108,8 @@ func (oc *OpenAIConnector) resolveProxyRoot(meta *UserLoginMetadata) string { if oc == nil { return "" } - if meta != nil { - if raw := strings.TrimSpace(meta.BaseURL); raw != "" { - return normalizeProxyBaseURL(raw) - } + if raw := loginCredentialBaseURL(meta); raw != "" { + return normalizeProxyBaseURL(raw) } return "" } @@ -118,7 +123,7 @@ func (oc *OpenAIConnector) resolveExaProxyBaseURL(meta *UserLoginMetadata) strin } func (oc *OpenAIConnector) resolveOpenAIBaseURL() string { - base := strings.TrimSpace(oc.Config.Providers.OpenAI.BaseURL) + base := strings.TrimSpace(oc.modelProviderConfig(ProviderOpenAI).BaseURL) if base == "" { base = defaultOpenAIBaseURL } @@ -126,7 +131,7 @@ func (oc *OpenAIConnector) resolveOpenAIBaseURL() string { } func (oc *OpenAIConnector) resolveOpenRouterBaseURL() string { - base := strings.TrimSpace(oc.Config.Providers.OpenRouter.BaseURL) + base := strings.TrimSpace(oc.modelProviderConfig(ProviderOpenRouter).BaseURL) if base == "" { base = defaultOpenRouterBaseURL } @@ -140,9 +145,9 @@ func (oc *OpenAIConnector) resolveServiceConfig(meta *UserLoginMetadata) Service } if meta.Provider == ProviderMagicProxy { - base := normalizeProxyBaseURL(meta.BaseURL) + base := normalizeProxyBaseURL(loginCredentialBaseURL(meta)) if base != "" { - token := trimToken(meta.APIKey) + token := trimToken(loginCredentialAPIKey(meta)) services[serviceOpenRouter] = ServiceConfig{ BaseURL: joinProxyPath(base, "/openrouter/v1"), APIKey: token, @@ -183,73 +188,73 @@ func (oc *OpenAIConnector) resolveProviderAPIKey(meta *UserLoginMetadata) string } switch meta.Provider { case ProviderMagicProxy: - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } case ProviderOpenRouter: - if key := trimToken(oc.Config.Providers.OpenRouter.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenRouter).APIKey); key != "" { return key } - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } case ProviderOpenAI: - if key := trimToken(oc.Config.Providers.OpenAI.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenAI).APIKey); key != "" { return key } - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenAI) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenAI) } default: - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } return "" } func (oc *OpenAIConnector) resolveOpenAIAPIKey(meta *UserLoginMetadata) string { - if key := trimToken(oc.Config.Providers.OpenAI.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenAI).APIKey); key != "" { return key } if meta == nil { return "" } if meta.Provider == ProviderOpenAI { - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenAI) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenAI) } return "" } func (oc *OpenAIConnector) resolveOpenRouterAPIKey(meta *UserLoginMetadata) string { - if key := trimToken(oc.Config.Providers.OpenRouter.APIKey); key != "" { + if key := trimToken(oc.modelProviderConfig(ProviderOpenRouter).APIKey); key != "" { return key } if meta == nil { return "" } if meta.Provider == ProviderOpenRouter { - if key := trimToken(meta.APIKey); key != "" { + if key := trimToken(loginCredentialAPIKey(meta)); key != "" { return key } } if meta.Provider == ProviderMagicProxy { - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } return "" } @@ -261,21 +266,21 @@ func loginTokenForService(meta *UserLoginMetadata, service string) string { switch service { case serviceOpenAI: if meta.Provider == ProviderOpenAI { - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenAI) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenAI) } case serviceOpenRouter: if meta.Provider == ProviderOpenRouter || meta.Provider == ProviderMagicProxy { - return trimToken(meta.APIKey) + return trimToken(loginCredentialAPIKey(meta)) } - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.OpenRouter) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.OpenRouter) } case serviceExa: - if meta.ServiceTokens != nil { - return trimToken(meta.ServiceTokens.Exa) + if tokens := loginCredentialServiceTokens(meta); tokens != nil { + return trimToken(tokens.Exa) } } return "" diff --git a/bridges/ai/tool_availability_configured_test.go b/bridges/ai/tool_availability_configured_test.go index 3e940aab..43e9dd78 100644 --- a/bridges/ai/tool_availability_configured_test.go +++ b/bridges/ai/tool_availability_configured_test.go @@ -21,7 +21,7 @@ func TestToolAvailable_WebSearch_RequiresAnyProviderKey(t *testing.T) { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{}, + Web: &WebToolsConfig{Search: &SearchConfig{}}, }, }, }, @@ -48,9 +48,9 @@ func TestToolAvailable_WebSearch_WithProviderKey(t *testing.T) { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Search: &SearchConfig{ + Web: &WebToolsConfig{Search: &SearchConfig{ Exa: ProviderExaConfig{APIKey: "test"}, - }, + }}, }, }, }, @@ -71,9 +71,9 @@ func TestToolAvailable_WebFetch_DirectDisabledAndNoExaKey(t *testing.T) { connector: &OpenAIConnector{ Config: Config{ Tools: ToolProvidersConfig{ - Fetch: &FetchConfig{ + Web: &WebToolsConfig{Fetch: &FetchConfig{ Direct: ProviderDirectConfig{Enabled: boolPtr(false)}, - }, + }}, }, }, }, diff --git a/bridges/ai/tool_configured.go b/bridges/ai/tool_configured.go index bb8e096d..8b61096a 100644 --- a/bridges/ai/tool_configured.go +++ b/bridges/ai/tool_configured.go @@ -17,10 +17,10 @@ func (oc *AIClient) effectiveSearchConfig(_ context.Context) *search.Config { return effectiveToolConfig( oc, func(connector *OpenAIConnector) *search.Config { - if connector == nil { + if connector == nil || connector.Config.Tools.Web == nil { return nil } - return mapSearchConfig(connector.Config.Tools.Search) + return mapSearchConfig(connector.Config.Tools.Web.Search) }, applyLoginTokensToSearchConfig, func(cfg *search.Config) *search.Config { return search.ApplyEnvDefaults(cfg).WithDefaults() }, @@ -31,10 +31,10 @@ func (oc *AIClient) effectiveFetchConfig(_ context.Context) *fetch.Config { return effectiveToolConfig( oc, func(connector *OpenAIConnector) *fetch.Config { - if connector == nil { + if connector == nil || connector.Config.Tools.Web == nil { return nil } - return mapFetchConfig(connector.Config.Tools.Fetch) + return mapFetchConfig(connector.Config.Tools.Web.Fetch) }, applyLoginTokensToFetchConfig, func(cfg *fetch.Config) *fetch.Config { return fetch.ApplyEnvDefaults(cfg).WithDefaults() }, diff --git a/bridges/ai/tools.go b/bridges/ai/tools.go index 84a59f3a..401e845e 100644 --- a/bridges/ai/tools.go +++ b/bridges/ai/tools.go @@ -1361,7 +1361,7 @@ func resolveOpenAITTSBaseURL(btc *BridgeToolContext, providerBaseURL string) (st } } } - if root := normalizeProxyBaseURL(meta.BaseURL); root != "" { + if root := normalizeProxyBaseURL(loginCredentialBaseURL(meta)); root != "" { return joinProxyPath(root, "/openai/v1"), true } diff --git a/bridges/ai/tools_analyze_image.go b/bridges/ai/tools_analyze_image.go index 8b03aaaa..2c49f790 100644 --- a/bridges/ai/tools_analyze_image.go +++ b/bridges/ai/tools_analyze_image.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/beeper/agentremote/pkg/shared/media" - bridgesdk "github.com/beeper/agentremote/sdk" ) // executeAnalyzeImage analyzes an image with a custom prompt using vision capabilities. @@ -80,7 +79,7 @@ func executeAnalyzeImage(ctx context.Context, args map[string]any) (string, erro return "", errors.New("unsupported URL scheme, must be http://, https://, mxc://, or data URL") } - ctxPrompt := PromptContext{PromptContext: bridgesdk.UserPromptContext( + ctxPrompt := UserPromptContext( PromptBlock{ Type: PromptBlockImage, ImageB64: imageB64, @@ -90,7 +89,7 @@ func executeAnalyzeImage(ctx context.Context, args map[string]any) (string, erro Type: PromptBlockText, Text: prompt, }, - )} + ) // Call the AI provider for vision analysis resp, err := btc.Client.provider.Generate(ctx, GenerateParams{ diff --git a/bridges/ai/tools_search_fetch.go b/bridges/ai/tools_search_fetch.go index e610e884..25b8548d 100644 --- a/bridges/ai/tools_search_fetch.go +++ b/bridges/ai/tools_search_fetch.go @@ -213,7 +213,7 @@ func applyExaProxyDefaultsTo(baseURL *string, apiKey *string, meta *UserLoginMet } if *apiKey == "" { if meta != nil && meta.Provider == ProviderMagicProxy { - if token := strings.TrimSpace(meta.APIKey); token != "" { + if token := loginCredentialAPIKey(meta); token != "" { *apiKey = token } } diff --git a/bridges/ai/tools_search_fetch_test.go b/bridges/ai/tools_search_fetch_test.go index 7fcea4bd..a8006da4 100644 --- a/bridges/ai/tools_search_fetch_test.go +++ b/bridges/ai/tools_search_fetch_test.go @@ -10,8 +10,10 @@ func TestApplyLoginTokensToSearchConfig_MagicProxyForcesExa(t *testing.T) { oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - APIKey: "magic-token", - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + APIKey: "magic-token", + BaseURL: "https://bai.bt.hn/team/proxy", + }, } cfg := &search.Config{ Provider: search.ProviderExa, @@ -60,7 +62,9 @@ func TestApplyLoginTokensToSearchConfig_DefaultExaEndpointDoesNotForceExa(t *tes oc := &OpenAIConnector{} meta := &UserLoginMetadata{ Provider: ProviderOpenRouter, - APIKey: "openrouter-token", + Credentials: &LoginCredentials{ + APIKey: "openrouter-token", + }, } cfg := &search.Config{ Provider: search.ProviderExa, diff --git a/bridges/ai/tools_tts_test.go b/bridges/ai/tools_tts_test.go index d7924773..14ef62ad 100644 --- a/bridges/ai/tools_tts_test.go +++ b/bridges/ai/tools_tts_test.go @@ -25,7 +25,9 @@ func newTTSTestBridgeContext(meta *UserLoginMetadata, oc *OpenAIConnector) *Brid func TestResolveOpenAITTSBaseURLMagicProxy(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - BaseURL: "https://bai.bt.hn/team/proxy", + Credentials: &LoginCredentials{ + BaseURL: "https://bai.bt.hn/team/proxy", + }, } btc := newTTSTestBridgeContext(meta, &OpenAIConnector{}) @@ -42,7 +44,9 @@ func TestResolveOpenAITTSBaseURLMagicProxy(t *testing.T) { func TestResolveOpenAITTSBaseURLMagicProxyWithoutConnector(t *testing.T) { meta := &UserLoginMetadata{ Provider: ProviderMagicProxy, - BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + Credentials: &LoginCredentials{ + BaseURL: "https://bai.bt.hn/team/proxy/openrouter/v1", + }, } btc := newTTSTestBridgeContext(meta, nil) @@ -60,9 +64,9 @@ func TestResolveOpenAITTSBaseURLOpenAIProviderUsesConfiguredBase(t *testing.T) { meta := &UserLoginMetadata{Provider: ProviderOpenAI} oc := &OpenAIConnector{ Config: Config{ - Providers: ProvidersConfig{ - OpenAI: ProviderConfig{BaseURL: "https://openai.example/v1"}, - }, + Models: &ModelsConfig{Providers: map[string]ModelProviderConfig{ + ProviderOpenAI: {BaseURL: "https://openai.example/v1"}, + }}, }, } btc := newTTSTestBridgeContext(meta, oc) diff --git a/bridges/ai/turn_validation.go b/bridges/ai/turn_validation.go deleted file mode 100644 index ed3cf51e..00000000 --- a/bridges/ai/turn_validation.go +++ /dev/null @@ -1,208 +0,0 @@ -package ai - -import ( - "strings" - - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" -) - -// IsGoogleModel returns true if the model ID looks like a Google/Gemini model. -func IsGoogleModel(modelID string) bool { - lower := strings.ToLower(modelID) - return strings.HasPrefix(lower, "google/") || - strings.HasPrefix(lower, "gemini") || - strings.Contains(lower, "/gemini") -} - -// SanitizeGoogleTurnOrdering fixes prompt ordering for Google models: -// - Merges consecutive user messages -// - Merges consecutive assistant messages -// - Prepends a synthetic user turn if history starts with an assistant message -func SanitizeGoogleTurnOrdering(prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { - if len(prompt) == 0 { - return prompt - } - - // Separate system messages (keep at front) from conversation messages - var system []openai.ChatCompletionMessageParamUnion - var conversation []openai.ChatCompletionMessageParamUnion - for _, msg := range prompt { - if chatMessageRole(msg) == "system" { - system = append(system, msg) - } else { - conversation = append(conversation, msg) - } - } - - if len(conversation) == 0 { - return prompt - } - - // Merge consecutive same-role messages - merged := mergeConsecutiveSameRole(conversation) - - // If the first non-system message is assistant, prepend a synthetic user turn - if len(merged) > 0 && chatMessageRole(merged[0]) == "assistant" { - merged = append([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("(continued from previous session)"), - }, merged...) - } - - return append(system, merged...) -} - -// mergeConsecutiveSameRole combines adjacent messages with the same role. -// For user messages, it preserves multimodal content parts (images, etc.) -// by collecting all parts from the run into a single OfArrayOfContentParts message. -// For assistant messages, it concatenates text bodies with double newlines. -func mergeConsecutiveSameRole(msgs []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { - if len(msgs) <= 1 { - return msgs - } - - var result []openai.ChatCompletionMessageParamUnion - i := 0 - for i < len(msgs) { - role := chatMessageRole(msgs[i]) - j := i + 1 - for j < len(msgs) && chatMessageRole(msgs[j]) == role { - j++ - } - - // Single message, no merging needed — keep as-is. - if j == i+1 { - result = append(result, msgs[i]) - i = j - continue - } - - // Multiple consecutive messages with the same role — merge them. - run := msgs[i:j] - switch role { - case "user": - result = append(result, mergeUserMessages(run)) - case "assistant": - // Assistant messages are always text-only (images go in synthetic user messages). - var body string - for _, m := range run { - nextBody := chatMessageBody(m) - if nextBody != "" { - if body != "" { - body += "\n\n" - } - body += nextBody - } - } - result = append(result, openai.AssistantMessage(body)) - default: - // For other roles, concatenate text. - var body string - for _, m := range run { - nextBody := chatMessageBody(m) - if nextBody != "" { - if body != "" { - body += "\n\n" - } - body += nextBody - } - } - result = append(result, openai.UserMessage(body)) - } - i = j - } - return result -} - -// mergeUserMessages merges a run of consecutive user messages into one, -// preserving multimodal content parts (OfArrayOfContentParts) if any message has them. -func mergeUserMessages(run []openai.ChatCompletionMessageParamUnion) openai.ChatCompletionMessageParamUnion { - // Check if any message in the run has multimodal parts. - hasMultimodal := false - for _, m := range run { - if m.OfUser != nil && len(m.OfUser.Content.OfArrayOfContentParts) > 0 { - hasMultimodal = true - break - } - } - - // If no multimodal content, do the simple text merge. - if !hasMultimodal { - var body string - for _, m := range run { - nextBody := chatMessageBody(m) - if nextBody != "" { - if body != "" { - body += "\n\n" - } - body += nextBody - } - } - return openai.UserMessage(body) - } - - // Collect all content parts, converting plain-text messages to text parts. - var allParts []openai.ChatCompletionContentPartUnionParam - for idx, m := range run { - if m.OfUser == nil { - continue - } - if len(m.OfUser.Content.OfArrayOfContentParts) > 0 { - // Add a separator text part between merged messages (except the first). - if idx > 0 && len(allParts) > 0 { - allParts = append(allParts, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: "\n\n"}, - }) - } - allParts = append(allParts, m.OfUser.Content.OfArrayOfContentParts...) - } else if m.OfUser.Content.OfString.Value != "" { - if len(allParts) > 0 { - allParts = append(allParts, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: "\n\n"}, - }) - } - allParts = append(allParts, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: m.OfUser.Content.OfString.Value}, - }) - } - } - - return openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfArrayOfContentParts: allParts, - }, - }, - } -} - -// chatMessageRole extracts the role string from a ChatCompletionMessageParamUnion. -// GetRole() returns empty strings at construction time (constant types marshal lazily), -// so we check which Of* field is populated instead. -func chatMessageRole(msg openai.ChatCompletionMessageParamUnion) string { - if !param.IsOmitted(msg.OfSystem) { - return "system" - } - if !param.IsOmitted(msg.OfUser) { - return "user" - } - if !param.IsOmitted(msg.OfAssistant) { - return "assistant" - } - if !param.IsOmitted(msg.OfTool) { - return "tool" - } - if !param.IsOmitted(msg.OfDeveloper) { - return "developer" - } - return "user" -} - -// chatMessageBody extracts the text body from a ChatCompletionMessageParamUnion. -func chatMessageBody(msg openai.ChatCompletionMessageParamUnion) string { - c := msg.GetContent() - if s, ok := c.AsAny().(*string); ok && s != nil { - return *s - } - return "" -} diff --git a/bridges/ai/turn_validation_test.go b/bridges/ai/turn_validation_test.go deleted file mode 100644 index 80ca9af6..00000000 --- a/bridges/ai/turn_validation_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package ai - -import ( - "testing" - - "github.com/openai/openai-go/v3" -) - -func TestIsGoogleModel(t *testing.T) { - tests := []struct { - model string - want bool - }{ - {"google/gemini-2.5-flash", true}, - {"google/gemini-3.1-pro-preview", true}, - {"gemini-pro", true}, - {"openrouter/google/gemini-flash", true}, - {"anthropic/claude-sonnet-4.5", false}, - {"openai/gpt-5", false}, - {"", false}, - } - for _, tt := range tests { - if got := IsGoogleModel(tt.model); got != tt.want { - t.Errorf("IsGoogleModel(%q) = %v, want %v", tt.model, got, tt.want) - } - } -} - -func TestSanitizeGoogleTurnOrdering_MergesConsecutiveUser(t *testing.T) { - prompt := []openai.ChatCompletionMessageParamUnion{ - openai.SystemMessage("system"), - openai.UserMessage("hello"), - openai.UserMessage("world"), - openai.AssistantMessage("hi"), - } - result := SanitizeGoogleTurnOrdering(prompt) - if hasConsecutiveUserOrAssistantRoles(result) { - t.Fatal("expected sanitized prompt to be valid") - } - // system + merged-user + assistant = 3 - if len(result) != 3 { - t.Fatalf("expected 3 messages, got %d", len(result)) - } -} - -func TestSanitizeGoogleTurnOrdering_PrependsSyntheticUser(t *testing.T) { - prompt := []openai.ChatCompletionMessageParamUnion{ - openai.SystemMessage("system"), - openai.AssistantMessage("I was speaking"), - openai.UserMessage("ok"), - } - result := SanitizeGoogleTurnOrdering(prompt) - if hasConsecutiveUserOrAssistantRoles(result) { - t.Fatal("expected sanitized prompt to be valid") - } - // system + synthetic-user + assistant + user = 4 - if len(result) != 4 { - t.Fatalf("expected 4 messages, got %d", len(result)) - } - // First non-system should be user - if chatMessageRole(result[1]) != "user" { - t.Fatalf("expected synthetic user message, got %s", chatMessageRole(result[1])) - } -} - -func TestSanitizeGoogleTurnOrdering_Empty(t *testing.T) { - result := SanitizeGoogleTurnOrdering(nil) - if result != nil { - t.Fatal("expected nil for nil input") - } -} - -func TestSanitizeGoogleTurnOrdering_AlreadyValid(t *testing.T) { - prompt := []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("hello"), - openai.AssistantMessage("hi"), - } - result := SanitizeGoogleTurnOrdering(prompt) - if len(result) != 2 { - t.Fatalf("expected 2 messages for already-valid prompt, got %d", len(result)) - } -} - -func TestChatMessageRole(t *testing.T) { - tests := []struct { - msg openai.ChatCompletionMessageParamUnion - want string - }{ - {openai.SystemMessage("sys"), "system"}, - {openai.UserMessage("usr"), "user"}, - {openai.AssistantMessage("ast"), "assistant"}, - } - for _, tt := range tests { - if got := chatMessageRole(tt.msg); got != tt.want { - t.Errorf("chatMessageRole() = %q, want %q", got, tt.want) - } - } -} - -func hasConsecutiveUserOrAssistantRoles(prompt []openai.ChatCompletionMessageParamUnion) bool { - lastRole := "" - for _, msg := range prompt { - role := chatMessageRole(msg) - if role == "system" { - continue - } - if role == lastRole && (role == "user" || role == "assistant") { - return true - } - lastRole = role - } - return false -} diff --git a/bridges/codex/connector.go b/bridges/codex/connector.go index ff52d318..1dfc66f3 100644 --- a/bridges/codex/connector.go +++ b/bridges/codex/connector.go @@ -29,7 +29,7 @@ type CodexConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*CodexClient, *Config] db *dbutil.Database clientsMu sync.Mutex diff --git a/bridges/codex/connector_test.go b/bridges/codex/connector_test.go index c6fd518d..f3240237 100644 --- a/bridges/codex/connector_test.go +++ b/bridges/codex/connector_test.go @@ -38,6 +38,13 @@ func TestGetCapabilitiesEnablesContactListProvisioning(t *testing.T) { } } +func TestGetNameUsesDefaultCommandPrefixBeforeStartup(t *testing.T) { + conn := NewConnector() + if got := conn.GetName().DefaultCommandPrefix; got != "!ai" { + t.Fatalf("expected default command prefix !ai, got %q", got) + } +} + func TestHostAuthLoginIDUsesDedicatedPrefix(t *testing.T) { conn := NewConnector() mxid := id.UserID("@alice:example.com") diff --git a/bridges/codex/constructors.go b/bridges/codex/constructors.go index f36d67df..8e9cf395 100644 --- a/bridges/codex/constructors.go +++ b/bridges/codex/constructors.go @@ -33,7 +33,7 @@ func NewConnector() *CodexConnector { Description: "Provide externally managed ChatGPT id/access tokens.", }, } - cc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + cc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*CodexClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "codex", Description: "A Matrix↔Codex bridge built on mautrix-go bridgev2.", ProtocolID: "ai-codex", @@ -65,7 +65,7 @@ func NewConnector() *CodexConnector { BeeperBridgeType: "codex", DefaultPort: 29346, DefaultCommandPrefix: func() string { - return cc.Config.Bridge.CommandPrefix + return bridgesdk.ResolveCommandPrefix(cc.Config.Bridge.CommandPrefix, "!ai") }, FillBridgeInfo: func(portal *bridgev2.Portal, content *event.BridgeEventContent) { if portal == nil { @@ -76,10 +76,10 @@ func NewConnector() *CodexConnector { ExampleConfig: exampleNetworkConfig, ConfigData: &cc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, AcceptLogin: func(login *bridgev2.UserLogin) (bool, string) { return bridgesdk.AcceptProviderLogin(login, ProviderCodex, "This bridge only supports Codex logins.", cc.codexEnabled, "Codex integration is disabled in the configuration.", func(login *bridgev2.UserLogin) string { return loginMetadata(login).Provider diff --git a/bridges/codex/stream_mapping_test.go b/bridges/codex/stream_mapping_test.go index 74ca2150..ed1c3393 100644 --- a/bridges/codex/stream_mapping_test.go +++ b/bridges/codex/stream_mapping_test.go @@ -23,7 +23,7 @@ func attachTestTurn(state *streamingState, portal *bridgev2.Portal) { if state == nil { return } - conv := bridgesdk.NewConversation(context.Background(), nil, portal, bridgev2.EventSender{}, &bridgesdk.Config{}, nil) + conv := bridgesdk.NewConversation(context.Background(), nil, portal, bridgev2.EventSender{}, &bridgesdk.Config[*CodexClient, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), nil, nil) turn.SetID(state.turnID) state.turn = turn diff --git a/bridges/dummybridge/bridge.go b/bridges/dummybridge/bridge.go index c59c0e1d..656be714 100644 --- a/bridges/dummybridge/bridge.go +++ b/bridges/dummybridge/bridge.go @@ -31,15 +31,14 @@ func (dc *DummyBridgeConnector) loggerForLogin(login *bridgev2.UserLogin) zerolo return login.Log.With().Str("component", "dummybridge").Logger() } -func sessionFromAny(session any) (*dummySession, error) { - dummy, ok := session.(*dummySession) - if !ok || dummy == nil || dummy.login == nil { +func requireSession(session *dummySession) (*dummySession, error) { + if session == nil || session.login == nil { return nil, errors.New("dummybridge session is unavailable") } - return dummy, nil + return session, nil } -func (dc *DummyBridgeConnector) onConnect(ctx context.Context, info *bridgesdk.LoginInfo) (any, error) { +func (dc *DummyBridgeConnector) onConnect(ctx context.Context, info *bridgesdk.LoginInfo) (*dummySession, error) { if info == nil || info.Login == nil { return nil, errors.New("missing login info") } @@ -58,12 +57,10 @@ func (dc *DummyBridgeConnector) onConnect(ctx context.Context, info *bridgesdk.L }, nil } -func (dc *DummyBridgeConnector) onDisconnect(session any) { - _, _ = sessionFromAny(session) -} +func (dc *DummyBridgeConnector) onDisconnect(_ *dummySession) {} -func (dc *DummyBridgeConnector) getContactList(ctx context.Context, session any) ([]*bridgev2.ResolveIdentifierResponse, error) { - dummy, err := sessionFromAny(session) +func (dc *DummyBridgeConnector) getContactList(ctx context.Context, session *dummySession) ([]*bridgev2.ResolveIdentifierResponse, error) { + dummy, err := requireSession(session) if err != nil { return nil, err } @@ -74,8 +71,8 @@ func (dc *DummyBridgeConnector) getContactList(ctx context.Context, session any) return []*bridgev2.ResolveIdentifierResponse{resp}, nil } -func (dc *DummyBridgeConnector) searchUsers(ctx context.Context, session any, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { - dummy, err := sessionFromAny(session) +func (dc *DummyBridgeConnector) searchUsers(ctx context.Context, session *dummySession, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { + dummy, err := requireSession(session) if err != nil { return nil, err } @@ -99,8 +96,8 @@ func (dc *DummyBridgeConnector) searchUsers(ctx context.Context, session any, qu return nil, nil } -func (dc *DummyBridgeConnector) resolveIdentifier(ctx context.Context, session any, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { - dummy, err := sessionFromAny(session) +func (dc *DummyBridgeConnector) resolveIdentifier(ctx context.Context, session *dummySession, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + dummy, err := requireSession(session) if err != nil { return nil, err } diff --git a/bridges/dummybridge/connector.go b/bridges/dummybridge/connector.go index b9338554..e90e139b 100644 --- a/bridges/dummybridge/connector.go +++ b/bridges/dummybridge/connector.go @@ -21,7 +21,7 @@ type DummyBridgeConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*dummySession, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI @@ -31,7 +31,7 @@ type DummyBridgeConnector struct { func NewConnector() *DummyBridgeConnector { dc := &DummyBridgeConnector{} - dc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + dc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*dummySession, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "dummybridge", Description: "A synthetic Matrix↔DummyBridge demo bridge built on the AgentRemote SDK.", ProtocolID: "ai-dummybridge", @@ -52,15 +52,15 @@ func NewConnector() *DummyBridgeConnector { BeeperBridgeType: "dummybridge", DefaultPort: 29349, DefaultCommandPrefix: func() string { - return dc.Config.Bridge.CommandPrefix + return bridgesdk.ResolveCommandPrefix(dc.Config.Bridge.CommandPrefix, "!dummybridge") }, ExampleConfig: exampleNetworkConfig, ConfigData: &dc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, AcceptLogin: func(login *bridgev2.UserLogin) (bool, string) { return bridgesdk.AcceptProviderLogin(login, ProviderDummyBridge, "This bridge only supports DummyBridge logins.", dc.enabled, "DummyBridge integration is disabled in the configuration.", func(login *bridgev2.UserLogin) string { return loginMetadata(login).Provider diff --git a/bridges/dummybridge/connector_test.go b/bridges/dummybridge/connector_test.go index 336ba315..263f60e4 100644 --- a/bridges/dummybridge/connector_test.go +++ b/bridges/dummybridge/connector_test.go @@ -38,3 +38,10 @@ func TestGetCapabilitiesExposeProvisioningSearchAndContacts(t *testing.T) { t.Fatal("expected search provisioning to be enabled") } } + +func TestGetNameUsesDefaultCommandPrefixBeforeStartup(t *testing.T) { + conn := NewConnector() + if got := conn.GetName().DefaultCommandPrefix; got != "!dummybridge" { + t.Fatalf("expected default command prefix !dummybridge, got %q", got) + } +} diff --git a/bridges/dummybridge/runtime.go b/bridges/dummybridge/runtime.go index 13021b2e..c01c43e3 100644 --- a/bridges/dummybridge/runtime.go +++ b/bridges/dummybridge/runtime.go @@ -247,7 +247,7 @@ const ( randomActionTransient randomActionKind = "data_transient" ) -func (dc *DummyBridgeConnector) onMessage(session any, conv *bridgesdk.Conversation, msg *bridgesdk.Message, turn *bridgesdk.Turn) error { +func (dc *DummyBridgeConnector) onMessage(session *dummySession, conv *bridgesdk.Conversation, msg *bridgesdk.Message, turn *bridgesdk.Turn) error { if conv == nil || turn == nil || msg == nil { return nil } @@ -265,7 +265,7 @@ func (dc *DummyBridgeConnector) onMessage(session any, conv *bridgesdk.Conversat if cmd.Name == "help" { return conv.SendNotice(turn.Context(), helpText()) } - dummy, err := sessionFromAny(session) + dummy, err := requireSession(session) if err != nil { return err } diff --git a/bridges/dummybridge/runtime_test.go b/bridges/dummybridge/runtime_test.go index 53b5b1fc..60290bd4 100644 --- a/bridges/dummybridge/runtime_test.go +++ b/bridges/dummybridge/runtime_test.go @@ -56,7 +56,7 @@ func (r *advancingRuntime) sleepFn(_ context.Context, delay time.Duration) error } func newTestTurn() *bridgesdk.Turn { - cfg := &bridgesdk.Config{ + cfg := &bridgesdk.Config[*dummySession, *struct{}]{ ProviderIdentity: bridgesdk.ProviderIdentity{IDPrefix: "dummybridge", StatusNetwork: "dummybridge"}, } // These tests only exercise turn-local streaming behavior. Login/portal are diff --git a/bridges/openclaw/connector.go b/bridges/openclaw/connector.go index a497397d..d6e27867 100644 --- a/bridges/openclaw/connector.go +++ b/bridges/openclaw/connector.go @@ -23,7 +23,7 @@ type OpenClawConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*OpenClawClient, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI @@ -41,7 +41,7 @@ type openClawLoginPrefill struct { func NewConnector() *OpenClawConnector { oc := &OpenClawConnector{} - oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*OpenClawClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "openclaw", Description: "A Matrix↔OpenClaw bridge built on mautrix-go bridgev2.", ProtocolID: "ai-openclaw", @@ -75,10 +75,10 @@ func NewConnector() *OpenClawConnector { ExampleConfig: exampleNetworkConfig, ConfigData: &oc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, NetworkCapabilities: func() *bridgev2.NetworkGeneralCapabilities { caps := agentremote.DefaultNetworkCapabilities() caps.DisappearingMessages = false diff --git a/bridges/openclaw/manager.go b/bridges/openclaw/manager.go index 45f04526..96afc01f 100644 --- a/bridges/openclaw/manager.go +++ b/bridges/openclaw/manager.go @@ -593,7 +593,11 @@ func (m *openClawManager) HandleMatrixMessage(ctx context.Context, msg *bridgev2 return nil, err } if meta.OpenClawDMCreatedFromContact && meta.OpenClawSessionID == "" && isOpenClawSyntheticDMSessionKey(meta.OpenClawSessionKey) { - go m.syncSessions(m.client.BackgroundContext(ctx)) + go func() { + if err := m.syncSessions(m.client.BackgroundContext(ctx)); err != nil { + m.client.Log().Debug().Err(err).Str("session_key", meta.OpenClawSessionKey).Msg("Failed to refresh OpenClaw sessions after synthetic DM message") + } + }() } return &bridgev2.MatrixMessageResponse{Pending: true}, nil } diff --git a/bridges/openclaw/stream_test.go b/bridges/openclaw/stream_test.go index c03ccfcb..dccca886 100644 --- a/bridges/openclaw/stream_test.go +++ b/bridges/openclaw/stream_test.go @@ -64,7 +64,7 @@ func (testMatrixAPI) GetEvent(context.Context, id.RoomID, id.EventID) (*event.Ev } func newOpenClawTestTurn(turnID string) *bridgesdk.Turn { - conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &bridgesdk.Config{}, nil) + conv := bridgesdk.NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &bridgesdk.Config[*OpenClawClient, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), nil, nil) turn.SetID(turnID) return turn @@ -301,7 +301,7 @@ func TestEmitStreamPartSerializesTurnCreation(t *testing.T) { oc := newOpenClawTestClient(map[string]*openClawStreamState{}) oc.UserLogin = &bridgev2.UserLogin{Bridge: &bridgev2.Bridge{Bot: testMatrixAPI{}}} oc.connector = &OpenClawConnector{} - oc.connector.sdkConfig = &bridgesdk.Config{} + oc.connector.sdkConfig = &bridgesdk.Config[*OpenClawClient, *Config]{} original := openClawNewSDKStreamTurn defer func() { openClawNewSDKStreamTurn = original }() diff --git a/bridges/opencode/connector.go b/bridges/opencode/connector.go index 6b44af91..1221566d 100644 --- a/bridges/opencode/connector.go +++ b/bridges/opencode/connector.go @@ -22,7 +22,7 @@ type OpenCodeConnector struct { *agentremote.ConnectorBase br *bridgev2.Bridge Config Config - sdkConfig *bridgesdk.Config + sdkConfig *bridgesdk.Config[*OpenCodeClient, *Config] clientsMu sync.Mutex clients map[networkid.UserLoginID]bridgev2.NetworkAPI @@ -42,7 +42,7 @@ func NewConnector() *OpenCodeConnector { Description: "Let the bridge spawn and manage OpenCode processes for you.", }, } - oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams{ + oc.sdkConfig = bridgesdk.NewStandardConnectorConfig(bridgesdk.StandardConnectorConfigParams[*OpenCodeClient, *Config, *PortalMetadata, *MessageMetadata, *UserLoginMetadata, *GhostMetadata]{ Name: "opencode", Description: "A Matrix↔OpenCode bridge built on mautrix-go bridgev2.", ProtocolID: "ai-opencode", @@ -50,7 +50,7 @@ func NewConnector() *OpenCodeConnector { ProviderIdentity: bridgesdk.ProviderIdentity{IDPrefix: "opencode", LogKey: "opencode_msg_id", StatusNetwork: "opencode"}, ClientCacheMu: &oc.clientsMu, ClientCache: &oc.clients, - GetCapabilities: func(session any, _ *bridgesdk.Conversation) *bridgesdk.RoomFeatures { + GetCapabilities: func(_ *OpenCodeClient, _ *bridgesdk.Conversation) *bridgesdk.RoomFeatures { return &bridgesdk.RoomFeatures{Custom: openCodeMatrixRoomFeatures()} }, InitConnector: func(bridge *bridgev2.Bridge) { @@ -72,10 +72,10 @@ func NewConnector() *OpenCodeConnector { ExampleConfig: exampleNetworkConfig, ConfigData: &oc.Config, ConfigUpgrader: configupgrade.SimpleUpgrader(upgradeConfig), - NewPortal: func() any { return &PortalMetadata{} }, - NewMessage: func() any { return &MessageMetadata{} }, - NewLogin: func() any { return &UserLoginMetadata{} }, - NewGhost: func() any { return &GhostMetadata{} }, + NewPortal: func() *PortalMetadata { return &PortalMetadata{} }, + NewMessage: func() *MessageMetadata { return &MessageMetadata{} }, + NewLogin: func() *UserLoginMetadata { return &UserLoginMetadata{} }, + NewGhost: func() *GhostMetadata { return &GhostMetadata{} }, AcceptLogin: func(login *bridgev2.UserLogin) (bool, string) { return bridgesdk.AcceptProviderLogin(login, ProviderOpenCode, "This bridge only supports OpenCode logins.", oc.openCodeEnabled, "OpenCode integration is disabled in the configuration.", func(login *bridgev2.UserLogin) string { return loginMetadata(login).Provider diff --git a/bridges/opencode/opencode_canonical_stream.go b/bridges/opencode/opencode_canonical_stream.go index 84fc4c4f..4664a333 100644 --- a/bridges/opencode/opencode_canonical_stream.go +++ b/bridges/opencode/opencode_canonical_stream.go @@ -66,7 +66,6 @@ func (m *OpenCodeManager) syncAssistantTextPart(ctx context.Context, inst *openC "delta": text, }) inst.appendPartTextContent(part.SessionID, part.ID, kind, text) - delivered = text } } else if missing, ok := strings.CutPrefix(text, delivered); ok && missing != "" { m.bridge.emitOpenCodeStreamEvent(ctx, portal, turnID, agentID, map[string]any{ diff --git a/cmd/generate-models/main.go b/cmd/generate-models/main.go index 3626de85..03cc7572 100644 --- a/cmd/generate-models/main.go +++ b/cmd/generate-models/main.go @@ -26,104 +26,44 @@ var modelConfig = struct { Aliases map[string]string }{ Models: map[string]string{ - // Anthropic (Claude) via OpenRouter - "anthropic/claude-haiku-4.5": "Claude Haiku 4.5", - "anthropic/claude-opus-4.1": "Claude 4.1 Opus", - "anthropic/claude-opus-4.5": "Claude Opus 4.5", - "anthropic/claude-opus-4.6": "Claude Opus 4.6", - "anthropic/claude-sonnet-4": "Claude 4 Sonnet", - "anthropic/claude-sonnet-4.5": "Claude Sonnet 4.5", - "anthropic/claude-sonnet-4.6": "Claude Sonnet 4.6", - - // DeepSeek - "deepseek/deepseek-chat-v3-0324": "DeepSeek v3 (0324)", - "deepseek/deepseek-chat-v3.1": "DeepSeek v3.1", - "deepseek/deepseek-v3.1-terminus": "DeepSeek v3.1 Terminus", - "deepseek/deepseek-v3.2": "DeepSeek v3.2", - "deepseek/deepseek-r1": "DeepSeek R1 (Original)", - "deepseek/deepseek-r1-0528": "DeepSeek R1 (0528)", - "deepseek/deepseek-r1-distill-qwen-32b": "DeepSeek R1 (Qwen Distilled)", - - // Gemini (Google) via OpenRouter - "google/gemini-2.0-flash-001": "Gemini 2.0 Flash", - "google/gemini-2.0-flash-lite-001": "Gemini 2.0 Flash Lite", - "google/gemini-2.5-flash": "Gemini 2.5 Flash", - "google/gemini-2.5-flash-image": "Nano Banana", - "google/gemini-2.5-flash-lite": "Gemini 2.5 Flash Lite", - "google/gemini-2.5-pro": "Gemini 2.5 Pro", - "google/gemini-3-flash-preview": "Gemini 3 Flash", - "google/gemini-3-pro-image-preview": "Nano Banana Pro", - "google/gemini-3.1-flash-lite-preview": "Gemini 3.1 Flash Lite", - "google/gemini-3.1-pro-preview": "Gemini 3.1 Pro", - // For Gemini image generation, use Nano Banana / Nano Banana Pro. - - // GLM (Z.AI) - "z-ai/glm-4.5": "GLM 4.5", - "z-ai/glm-4.5-air": "GLM 4.5 Air", - "z-ai/glm-4.5v": "GLM 4.5V", - "z-ai/glm-4.6": "GLM 4.6", - "z-ai/glm-4.6v": "GLM 4.6V", - "z-ai/glm-4.7": "GLM 4.7", - "z-ai/glm-5": "GLM 5", - - // Kimi (Moonshot) - "moonshotai/kimi-k2": "Kimi K2 (0711)", - "moonshotai/kimi-k2-0905": "Kimi K2 (0905)", - "moonshotai/kimi-k2.5": "Kimi K2.5", - - // Llama (Meta) - "meta-llama/llama-3.3-70b-instruct": "Llama 3.3 70B", - "meta-llama/llama-4-maverick": "Llama 4 Maverick", - "meta-llama/llama-4-scout": "Llama 4 Scout", - - // MiniMax - "minimax/minimax-m2": "MiniMax M2", - "minimax/minimax-m2.1": "MiniMax M2.1", - "minimax/minimax-m2.5": "MiniMax M2.5", - - // OpenAI models via OpenRouter - "openai/gpt-4.1": "GPT-4.1", - "openai/gpt-4.1-mini": "GPT-4.1 Mini", - "openai/gpt-4.1-nano": "GPT-4.1 Nano", - "openai/gpt-4o-mini": "GPT-4o-mini", - "openai/gpt-5": "GPT-5", - "openai/gpt-5-image": "GPT ImageGen 1.5", - "openai/gpt-5-image-mini": "GPT ImageGen", - "openai/gpt-5-mini": "GPT-5 mini", - "openai/gpt-5-nano": "GPT-5 nano", - "openai/gpt-5.1": "GPT-5.1", - "openai/gpt-5.2": "GPT-5.2", - "openai/gpt-5.2-pro": "GPT-5.2 Pro", - "openai/gpt-5.3-chat": "GPT-5.3 Instant", - "openai/gpt-5.4": "GPT-5.4", - "openai/gpt-oss-20b": "GPT OSS 20B", - "openai/gpt-oss-120b": "GPT OSS 120B", - "openai/o3": "o3", - "openai/o3-mini": "o3-mini", - "openai/o3-pro": "o3 Pro", - "openai/o4-mini": "o4-mini", - - // Qwen (Alibaba) - "qwen/qwen2.5-vl-32b-instruct": "Qwen 2.5 32B", - "qwen/qwen3-32b": "Qwen 3 32B", - "qwen/qwen3-235b-a22b": "Qwen 3 235B", - "qwen/qwen3-coder": "Qwen 3 Coder", - - // xAI (Grok) - "x-ai/grok-3": "Grok 3", - "x-ai/grok-3-mini": "Grok 3 Mini", - "x-ai/grok-4": "Grok 4", - "x-ai/grok-4-fast": "Grok 4 Fast", - "x-ai/grok-4.1-fast": "Grok 4.1 Fast", + "anthropic/claude-haiku-4.5": "Claude Haiku 4.5", + "anthropic/claude-opus-4.6": "Claude Opus 4.6", + "anthropic/claude-sonnet-4.6": "Claude Sonnet 4.6", + "deepseek/deepseek-r1-0528": "DeepSeek R1 (0528)", + "deepseek/deepseek-v3.2": "DeepSeek v3.2", + "google/gemini-2.5-flash-lite": "Gemini 2.5 Flash Lite", + "google/gemini-2.5-pro": "Gemini 2.5 Pro", + "google/gemini-3-flash-preview": "Gemini 3 Flash", + "google/gemma-2-27b-it": "Gemma 2 27B", + "meta-llama/llama-4-maverick": "Llama 4 Maverick", + "minimax/minimax-m2.7": "MiniMax M2.7", + "mistralai/devstral-2512": "Devstral 2", + "mistralai/mistral-small-2603": "Mistral Small 4", + "moonshotai/kimi-k2.5": "Kimi K2.5", + "openai/gpt-5-mini": "GPT-5 mini", + "openai/gpt-5.2": "GPT-5.2", + "openai/gpt-5.3-codex": "GPT-5.3 Codex", + "openai/gpt-5.4": "GPT-5.4", + "openai/gpt-5.4-mini": "GPT-5.4 Mini", + "openai/o3": "o3", + "openai/o4-mini": "o4-mini", + "qwen/qwen2.5-vl-32b-instruct": "Qwen 2.5 32B", + "qwen/qwen3-coder-next": "Qwen 3 Coder Next", + "qwen/qwen3.5-flash-02-23": "Qwen 3.5 Flash", + "qwen/qwen3.5-plus-02-15": "Qwen 3.5 Plus", + "x-ai/grok-4.1-fast": "Grok 4.1 Fast", + "x-ai/grok-4.20-beta": "Grok 4.20 Beta", + "x-ai/grok-code-fast-1": "Grok Code Fast 1", + "z-ai/glm-5-turbo": "GLM 5 Turbo", }, Aliases: map[string]string{ // Default alias "beeper/default": "anthropic/claude-opus-4.6", // Stable aliases that can be remapped - "beeper/fast": "openai/gpt-5-mini", - "beeper/smart": "openai/gpt-5.2", - "beeper/reasoning": "openai/gpt-5.2", // Uses reasoning effort parameter + "beeper/fast": "openai/gpt-5.4-mini", + "beeper/smart": "openai/gpt-5.4", + "beeper/reasoning": "openai/o3", }, } @@ -188,15 +128,11 @@ func main() { } func run() error { - token := flag.String("openrouter-token", "", "OpenRouter API token") + token := flag.String("openrouter-token", "", "Optional OpenRouter API token") outputFile := flag.String("output", "bridges/ai/beeper_models_generated.go", "Output Go file") jsonFile := flag.String("json", "pkg/ai/beeper_models.json", "Output JSON file for clients") flag.Parse() - if *token == "" { - return fmt.Errorf("--openrouter-token is required") - } - models, err := fetchOpenRouterModels(*token) if err != nil { return fmt.Errorf("fetching models: %w", err) @@ -219,7 +155,9 @@ func fetchOpenRouterModels(token string) (map[string]OpenRouterModel, error) { if err != nil { return nil, err } - req.Header.Set("Authorization", "Bearer "+token) + if strings.TrimSpace(token) != "" { + req.Header.Set("Authorization", "Bearer "+token) + } client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) diff --git a/config.example.yaml b/config.example.yaml index e5c6a761..3605cb70 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -41,82 +41,104 @@ encryption: # AI Chats-specific options (shared with the embedded example in bridges/ai/integrations_config.go) network: - # Beeper Cloud credentials for automatic login (optional) - beeper: - user_mxid: "" # Owning Matrix user for the built-in Beeper Cloud login. - base_url: "" # Optional. If empty, login uses selected Beeper domain. - token: "" # Beeper Matrix access token + # Connector-specific configuration lives under the `network:` section of the + # main config file. - # Per-provider default models - providers: - beeper: - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - openai: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://api.openai.com/v1 - base_url: "https://api.openai.com/v1" - default_model: "openai/gpt-5.2" - openrouter: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://openrouter.ai/api/v1 - base_url: "https://openrouter.ai/api/v1" - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" + beeper: + user_mxid: "" + base_url: "" + token: "" -# Optional model catalog seeding (OpenClaw-style). -# models: -# mode: "merge" # merge | replace -# providers: -# openai: -# models: -# - id: "gpt-5.2" -# name: "GPT-5.2" -# reasoning: true -# input: ["text", "image"] -# context_window: 128000 -# max_tokens: 8192 + models: + providers: + openai: + api_key: "" + base_url: "https://api.openai.com/v1" + models: [] + openrouter: + api_key: "" + base_url: "https://openrouter.ai/api/v1" + models: [] + magic_proxy: + api_key: "" + base_url: "" + models: [] -# Global settings -default_system_prompt: | - You are a helpful, concise assistant. + default_system_prompt: | + You are a helpful, concise assistant. Ask clarifying questions when needed. Follow the user's intent and be accurate. model_cache_duration: 6h - # External tool providers (search + fetch). Proxy is optional. - tools: - search: - provider: "exa" - fallbacks: [] - exa: - api_key: "" - base_url: "https://api.exa.ai" - type: "auto" - num_results: 5 - include_text: false - text_max_chars: 500 - fetch: - provider: "exa" - fallbacks: ["direct"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - include_text: true - text_max_chars: 5000 - direct: - enabled: true - timeout_seconds: 30 - max_chars: 50000 - max_redirects: 3 + messages: + direct_chat: + history_limit: 20 + group_chat: + history_limit: 50 + queue: + mode: "collect" + debounce_ms: 1000 + cap: 20 + drop: "summarize" + + commands: + owner_allow_from: [] + + tool_approvals: + enabled: true + ttl_seconds: 600 + require_for_mcp: true + require_for_tools: ["message", "cron", "gravatar_set", "create_agent", "fork_agent", "edit_agent", "delete_agent", "modify_room", "sessions_send", "sessions_spawn", "run_internal_command"] + + channels: + matrix: + reply_to_mode: "first" + + session: + scope: "per-sender" + main_key: "main" - # Media understanding/transcription (OpenClaw-style). + tools: + web: + search: + provider: "exa" + fallbacks: [] + exa: + api_key: "" + base_url: "https://api.exa.ai" + type: "auto" + num_results: 5 + include_text: false + text_max_chars: 500 + highlights: true + fetch: + provider: "direct" + fallbacks: ["exa"] + exa: + api_key: "" + base_url: "https://api.exa.ai" + include_text: true + text_max_chars: 5000 + direct: + enabled: true + timeout_seconds: 30 + max_chars: 50000 + max_redirects: 3 + links: + enabled: true + max_urls_inbound: 3 + max_urls_outbound: 5 + fetch_timeout: 10s + max_content_chars: 500 + max_page_bytes: 10485760 + max_image_bytes: 5242880 + cache_ttl: 1h + mcp: + enable_stdio: false + vfs: + apply_patch: + enabled: false + allow_models: [] media: concurrency: 2 image: @@ -132,7 +154,6 @@ default_system_prompt: | enabled: true prompt: "Transcribe the audio." language: "" - # CLI transcription auto-detection (whisper/whisper.cpp) is not implemented yet. max_bytes: 20971520 timeout_seconds: 60 models: @@ -147,161 +168,55 @@ default_system_prompt: | - provider: "openrouter" model: "google/gemini-3-flash-preview" - # Memory search configuration (OpenClaw-style). - # Indexes MEMORY.md + memory/*.md stored in the bridge DB. - # Per-agent overrides can be set via agent definitions. - # Current runtime behavior is lexical-only. - memory_search: - enabled: true - sources: ["memory"] - extra_paths: [] - local: - model_path: "" - model_cache_dir: "" - base_url: "" - api_key: "" - store: - driver: "sqlite" - path: "" - chunking: - tokens: 400 - overlap: 80 - sync: - on_session_start: true - on_search: true - watch: true - watch_debounce_ms: 1500 - interval_minutes: 0 - sessions: - delta_bytes: 100000 - delta_messages: 50 - query: - max_results: 6 - min_score: 0.35 - hybrid: - candidate_multiplier: 4 - cache: - enabled: true - max_entries: -1 # Unlimited when cache.enabled is true; experimental.session_memory does not change cache size semantics. - experimental: - session_memory: false - - # Tool policy (OpenClaw-style). Controls allow/deny lists and profiles. - # tool_policy: - # profile: "full" - # # group:openclaw is the strict OpenClaw native tool set. - # # group:ai-bridge includes ai-bridge-only extras (beeper_docs, gravatar_*, tts, image_generate, calculator, etc). - # allow: ["group:openclaw", "group:ai-bridge"] - # deny: [] - # subagents: - # tools: - # deny: ["sessions_list", "sessions_history", "sessions_send"] - - # Agent defaults (OpenClaw-style). - # agents: - # defaults: - # subagents: - # model: "anthropic/claude-sonnet-4.5" - # allow_agents: ["*"] - # typing_mode: "instant" # never|instant|thinking|message (message ignores NO_REPLY) - # typing_interval_seconds: 6 # refresh cadence, not start time (heartbeats never show typing) - - # Context pruning configuration (OpenClaw-style). - # Reduces token usage by intelligently truncating old tool results. - pruning: - # Enable proactive context pruning - enabled: true - - # Ratio of context window usage that triggers soft trimming (0.0-1.0) - # At 30% usage, large tool results start getting truncated - soft_trim_ratio: 0.3 - - # Ratio of context window usage that triggers hard clearing (0.0-1.0) - # At 50% usage, old tool results are replaced with placeholder - hard_clear_ratio: 0.5 - - # Number of recent assistant messages to protect from pruning - keep_last_assistants: 3 - - # Minimum total chars in prunable tool results before hard clear kicks in - min_prunable_chars: 50000 - - # Tool results larger than this are candidates for soft trimming - soft_trim_max_chars: 4000 - - # When soft trimming, keep this many chars from the start - soft_trim_head_chars: 1500 - - # When soft trimming, keep this many chars from the end - soft_trim_tail_chars: 1500 - - # Enable/disable hard clear phase - hard_clear_enabled: true - - # Placeholder text for hard-cleared tool results - hard_clear_placeholder: "[Old tool result content cleared]" - - # Tool patterns to allow/deny pruning (supports wildcards: exec*, *_search) - # Empty means all tools are prunable unless denied - # tools_allow: [] - # tools_deny: [] - - # --- LLM-based summarization (compaction) --- - # When enabled, uses an LLM to generate intelligent summaries of compacted - # content instead of just using placeholder text. This preserves context better. - - # Enable LLM summarization (default: true when pruning is enabled) - summarization_enabled: true - - # Model to use for generating summaries (default: fast model) - summarization_model: "openai/gpt-5.2" - - # Maximum tokens for generated summaries - max_summary_tokens: 500 - - # Maximum ratio of context that history can consume (0.0-1.0) - # When exceeded, oldest messages are summarized to fit budget - max_history_share: 0.5 - - # Token budget reserved for compaction output - reserve_tokens: 20000 - - # Additional instructions for the summarization model - # custom_instructions: "Focus on preserving code decisions and TODOs" - - # Identifier preservation policy for summaries: - # - strict (default): preserve opaque identifiers exactly - # - off: no special identifier-preservation instruction - # - custom: use identifier_instructions below - identifier_policy: "strict" - # identifier_instructions: "Keep ticket IDs, hashes, and hostnames unchanged." - - # Optional pre-compaction overflow flush turn. - # Disabled by default to keep compaction independent from memory workflows. - overflow_flush: - enabled: false - soft_threshold_tokens: 4000 - prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." - system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - - # Link preview configuration. - # Automatically fetches metadata for URLs in messages to provide context to the AI - # and generate rich previews in outgoing AI responses. - link_previews: - # Enable link preview functionality (default: true) - enabled: true - - # Maximum number of URLs to fetch from user messages for AI context (default: 3) - max_urls_inbound: 3 - - # Maximum number of URLs to preview in AI responses (default: 5) - max_urls_outbound: 5 - - # Timeout for fetching each URL (default: 10s) - fetch_timeout: 10s - - # Maximum characters from description to include in context (default: 500) - max_content_chars: 500 - - # Maximum page size to download in bytes (default: 10MB) - max_page_bytes: 10485760 + agents: + defaults: + model: + primary: "" + fallbacks: [] + image_model: + primary: "" + fallbacks: [] + image_generation_model: + primary: "" + fallbacks: [] + pdf_model: + primary: "" + fallbacks: [] + pdf_engine: "mistral-ocr" + compaction: + # cache-ttl keeps a cached compacted prompt for ttl; other modes can force + # more aggressive recomputation or fallback behavior in runtime defaults. + mode: "cache-ttl" + # ttl applies to cached compaction snapshots when mode uses cache expiry. + ttl: "1h" + # enabled gates soft trimming, hard clear, summarization, and token reserves. + enabled: true + soft_trim_ratio: 0.3 + hard_clear_ratio: 0.5 + # keep a few recent assistant turns and require enough prunable text before trimming. + keep_last_assistants: 3 + min_prunable_chars: 50000 + # soft trim keeps head/tail context from oversized tool results before full compaction. + soft_trim_max_chars: 4000 + soft_trim_head_chars: 1500 + soft_trim_tail_chars: 1500 + hard_clear_enabled: true + hard_clear_placeholder: "[Old tool result content cleared]" + # summarization condenses old history before hard clearing it entirely. + summarization_enabled: true + summarization_model: "openai/gpt-5.2" + max_summary_tokens: 500 + # safeguard preserves recent history/tokens; alternative modes may trade fidelity for space. + compaction_mode: "safeguard" + keep_recent_tokens: 20000 + max_history_share: 0.5 + reserve_tokens: 20000 + reserve_tokens_floor: 20000 + identifier_policy: "strict" + post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." + overflow_flush: + # overflow flush runs a last tool-only pass near the soft threshold before compaction. + enabled: true + soft_threshold_tokens: 4000 + prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." + system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." diff --git a/connector_builder.go b/connector_builder.go index f877d115..5e6920fb 100644 --- a/connector_builder.go +++ b/connector_builder.go @@ -14,8 +14,8 @@ type ConnectorSpec struct { AIRoomKind string Init func(*bridgev2.Bridge) - Start func(context.Context) error - Stop func(context.Context) + Start func(context.Context, *bridgev2.Bridge) error + Stop func(context.Context, *bridgev2.Bridge) Name func() bridgev2.BridgeName Config func() (example string, data any, upgrader configupgrade.Upgrader) @@ -62,14 +62,14 @@ func (c *ConnectorBase) Start(ctx context.Context) error { if c == nil || c.spec.Start == nil { return nil } - return c.spec.Start(ctx) + return c.spec.Start(ctx, c.br) } func (c *ConnectorBase) Stop(ctx context.Context) { if c == nil || c.spec.Stop == nil { return } - c.spec.Stop(ctx) + c.spec.Stop(ctx, c.br) } func (c *ConnectorBase) GetName() bridgev2.BridgeName { diff --git a/connector_builder_test.go b/connector_builder_test.go index 0c4d449d..9b751f84 100644 --- a/connector_builder_test.go +++ b/connector_builder_test.go @@ -14,15 +14,29 @@ import ( func TestConnectorBaseHookOrder(t *testing.T) { var order []string + wantBridge := &bridgev2.Bridge{} conn := NewConnector(ConnectorSpec{ - Init: func(*bridgev2.Bridge) { order = append(order, "init") }, - Start: func(context.Context) error { + Init: func(got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected init hook bridge %p, got %p", wantBridge, got) + } + order = append(order, "init") + }, + Start: func(_ context.Context, got *bridgev2.Bridge) error { + if got != wantBridge { + t.Fatalf("expected start hook bridge %p, got %p", wantBridge, got) + } order = append(order, "start") return nil }, - Stop: func(context.Context) { order = append(order, "stop") }, + Stop: func(_ context.Context, got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected stop hook bridge %p, got %p", wantBridge, got) + } + order = append(order, "stop") + }, }) - conn.Init(nil) + conn.Init(wantBridge) if err := conn.Start(context.Background()); err != nil { t.Fatalf("start returned error: %v", err) } @@ -149,7 +163,7 @@ func TestConnectorStopCanDisconnectCachedClients(t *testing.T) { "b": &fakeClient{}, } conn := NewConnector(ConnectorSpec{ - Stop: func(context.Context) { + Stop: func(context.Context, *bridgev2.Bridge) { StopClients(&mu, &clients) }, }) diff --git a/generate-models.sh b/generate-models.sh index 9503cc5e..c1400be4 100755 --- a/generate-models.sh +++ b/generate-models.sh @@ -1,6 +1,6 @@ #!/bin/bash # Generate AI models Go file from OpenRouter API -# Usage: ./generate-models.sh --openrouter-token="YOUR_TOKEN" +# Usage: ./generate-models.sh [--openrouter-token="YOUR_TOKEN"] # # This script fetches model capabilities from OpenRouter and generates # a Go file with model definitions. The generated file is checked into @@ -24,10 +24,10 @@ while [[ $# -gt 0 ]]; do shift ;; -h|--help) - echo "Usage: $0 --openrouter-token=TOKEN [--output=FILE]" + echo "Usage: $0 [--openrouter-token=TOKEN] [--output=FILE]" echo "" echo "Options:" - echo " --openrouter-token=TOKEN OpenRouter API token (required)" + echo " --openrouter-token=TOKEN Optional OpenRouter API token" echo " --output=FILE Output file path (default: bridges/ai/beeper_models_generated.go)" echo " --json=FILE Output JSON path (default: pkg/ai/beeper_models.json)" exit 0 @@ -43,12 +43,6 @@ while [[ $# -gt 0 ]]; do esac done -if [ -z "$OPENROUTER_TOKEN" ]; then - echo "Error: --openrouter-token is required" - echo "Usage: $0 --openrouter-token=TOKEN" - exit 1 -fi - # Change to script directory cd "$(dirname "$0")" diff --git a/helpers.go b/helpers.go index df2c7362..e403e3c2 100644 --- a/helpers.go +++ b/helpers.go @@ -450,23 +450,21 @@ func ApplyAgentRemoteBridgeInfo(content *event.BridgeEventContent, protocolID st } func SendAIRoomInfo(ctx context.Context, portal *bridgev2.Portal, aiKind string) bool { - if portal == nil || portal.MXID == "" { + if portal == nil || portal.MXID == "" || portal.Bridge == nil || portal.Bridge.Bot == nil { return false } if aiKind == "" { aiKind = AIRoomKindAgent } - //lint:ignore SA1019 bridgev2 currently exposes room-meta sending via portal internals - return portal.Internal().SendRoomMeta( - ctx, - nil, - time.Now(), - matrixevents.AIRoomInfoEventType, - "", - map[string]any{"type": aiKind}, - true, - nil, - ) + _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, matrixevents.AIRoomInfoEventType, "", &event.Content{ + Parsed: map[string]any{"type": aiKind}, + Raw: map[string]any{"com.beeper.exclude_from_timeline": true}, + }, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to send AI room info state event") + return false + } + return true } // findExistingMessage performs a two-phase message lookup: first by network diff --git a/pkg/agents/beeper.go b/pkg/agents/beeper.go index 13253a5b..1e2a7e2b 100644 --- a/pkg/agents/beeper.go +++ b/pkg/agents/beeper.go @@ -19,7 +19,7 @@ var BeeperAIAgent = &AgentDefinition{ Fallbacks: []string{ ModelClaudeSonnet, ModelOpenAIGPT52, - ModelZAIGLM47, + ModelZAIGLM5Turbo, }, }, Tools: &toolpolicy.ToolPolicyConfig{Profile: toolpolicy.ProfileFull}, diff --git a/pkg/agents/boss.go b/pkg/agents/boss.go index deb1118e..d47c0b69 100644 --- a/pkg/agents/boss.go +++ b/pkg/agents/boss.go @@ -13,7 +13,7 @@ var BossAgent = &AgentDefinition{ Fallbacks: []string{ ModelClaudeSonnet, ModelOpenAIGPT52, - ModelZAIGLM47, + ModelZAIGLM5Turbo, }, }, Tools: &toolpolicy.ToolPolicyConfig{Profile: toolpolicy.ProfileBoss}, diff --git a/pkg/agents/presets.go b/pkg/agents/presets.go index 804074f9..b16ad307 100644 --- a/pkg/agents/presets.go +++ b/pkg/agents/presets.go @@ -4,10 +4,10 @@ import "slices" // Model constants for preset agents (aligned with clawdbot recommended models). const ( - ModelClaudeSonnet = "anthropic/claude-sonnet-4.5" + ModelClaudeSonnet = "anthropic/claude-sonnet-4.6" ModelClaudeOpus = "anthropic/claude-opus-4.6" ModelOpenAIGPT52 = "openai/gpt-5.2" - ModelZAIGLM47 = "z-ai/glm-4.7" + ModelZAIGLM5Turbo = "z-ai/glm-5-turbo" ) // PresetAgents contains the default agent definitions: diff --git a/pkg/ai/beeper_models.json b/pkg/ai/beeper_models.json index 57904ff2..4cd1b0f5 100644 --- a/pkg/ai/beeper_models.json +++ b/pkg/ai/beeper_models.json @@ -18,8 +18,8 @@ ] }, { - "id": "anthropic/claude-opus-4.1", - "name": "Claude 4.1 Opus", + "id": "anthropic/claude-opus-4.6", + "name": "Claude Opus 4.6", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, @@ -27,16 +27,16 @@ "supports_reasoning": true, "supports_web_search": true, "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 32000, + "context_window": 1000000, + "max_output_tokens": 128000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "anthropic/claude-opus-4.5", - "name": "Claude Opus 4.5", + "id": "anthropic/claude-sonnet-4.6", + "name": "Claude Sonnet 4.6", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, @@ -44,1046 +44,408 @@ "supports_reasoning": true, "supports_web_search": true, "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 64000, + "context_window": 1000000, + "max_output_tokens": 128000, "available_tools": [ "web_search", "function_calling" ] }, { - "id": "anthropic/claude-opus-4.6", - "name": "Claude Opus 4.6", + "id": "deepseek/deepseek-r1-0528", + "name": "DeepSeek R1 (0528)", "provider": "openrouter", "api": "openai-completions", - "supports_vision": true, + "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1000000, - "max_output_tokens": 128000, + "supports_web_search": false, + "context_window": 163840, + "max_output_tokens": 65536, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "anthropic/claude-sonnet-4", - "name": "Claude 4 Sonnet", + "id": "deepseek/deepseek-v3.2", + "name": "DeepSeek v3.2", "provider": "openrouter", "api": "openai-completions", - "supports_vision": true, + "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 64000, + "supports_web_search": false, + "context_window": 163840, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "anthropic/claude-sonnet-4.5", - "name": "Claude Sonnet 4.5", + "id": "google/gemini-2.5-flash-lite", + "name": "Gemini 2.5 Flash Lite", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, + "supports_web_search": false, + "supports_audio": true, + "supports_video": true, "supports_pdf": true, - "context_window": 1000000, - "max_output_tokens": 64000, + "context_window": 1048576, + "max_output_tokens": 65535, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "anthropic/claude-sonnet-4.6", - "name": "Claude Sonnet 4.6", + "id": "google/gemini-2.5-pro", + "name": "Gemini 2.5 Pro", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": true, + "supports_web_search": false, + "supports_audio": true, + "supports_video": true, "supports_pdf": true, - "context_window": 1000000, - "max_output_tokens": 128000, + "context_window": 1048576, + "max_output_tokens": 65536, "available_tools": [ - "web_search", "function_calling" ] }, { - "id": "deepseek/deepseek-chat-v3-0324", - "name": "DeepSeek v3 (0324)", + "id": "google/gemini-3-flash-preview", + "name": "Gemini 3 Flash", "provider": "openrouter", "api": "openai-completions", - "supports_vision": false, + "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "context_window": 163840, - "max_output_tokens": 163840, + "supports_audio": true, + "supports_video": true, + "supports_pdf": true, + "context_window": 1048576, + "max_output_tokens": 65536, "available_tools": [ "function_calling" ] }, { - "id": "deepseek/deepseek-chat-v3.1", - "name": "DeepSeek v3.1", + "id": "google/gemma-2-27b-it", + "name": "Gemma 2 27B", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, + "supports_tool_calling": false, + "supports_reasoning": false, "supports_web_search": false, - "context_window": 32768, - "max_output_tokens": 7168, - "available_tools": [ - "function_calling" - ] + "context_window": 8192, + "max_output_tokens": 2048 }, { - "id": "deepseek/deepseek-r1", - "name": "DeepSeek R1 (Original)", + "id": "meta-llama/llama-4-maverick", + "name": "Llama 4 Maverick", "provider": "openrouter", "api": "openai-completions", - "supports_vision": false, + "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": true, + "supports_reasoning": false, "supports_web_search": false, - "context_window": 64000, - "max_output_tokens": 16000, + "context_window": 1048576, + "max_output_tokens": 16384, "available_tools": [ "function_calling" ] }, { - "id": "deepseek/deepseek-r1-0528", - "name": "DeepSeek R1 (0528)", + "id": "minimax/minimax-m2.7", + "name": "MiniMax M2.7", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "context_window": 163840, - "max_output_tokens": 65536, + "context_window": 204800, + "max_output_tokens": 131072, "available_tools": [ "function_calling" ] }, { - "id": "deepseek/deepseek-r1-distill-qwen-32b", - "name": "DeepSeek R1 (Qwen Distilled)", + "id": "mistralai/devstral-2512", + "name": "Devstral 2", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, - "supports_tool_calling": false, - "supports_reasoning": true, + "supports_tool_calling": true, + "supports_reasoning": false, "supports_web_search": false, - "context_window": 32768, - "max_output_tokens": 32768 + "context_window": 262144, + "available_tools": [ + "function_calling" + ] }, { - "id": "deepseek/deepseek-v3.1-terminus", - "name": "DeepSeek v3.1 Terminus", + "id": "mistralai/mistral-small-2603", + "name": "Mistral Small 4", "provider": "openrouter", "api": "openai-completions", - "supports_vision": false, + "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "context_window": 163840, + "context_window": 262144, "available_tools": [ "function_calling" ] }, { - "id": "deepseek/deepseek-v3.2", - "name": "DeepSeek v3.2", + "id": "moonshotai/kimi-k2.5", + "name": "Kimi K2.5", "provider": "openrouter", "api": "openai-completions", - "supports_vision": false, + "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "context_window": 163840, - "max_output_tokens": 65536, + "context_window": 262144, + "max_output_tokens": 65535, "available_tools": [ "function_calling" ] }, { - "id": "google/gemini-2.0-flash-001", - "name": "Gemini 2.0 Flash", + "id": "openai/gpt-5-mini", + "name": "GPT-5 mini", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, + "supports_reasoning": true, + "supports_web_search": true, "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 8192, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "google/gemini-2.0-flash-lite-001", - "name": "Gemini 2.0 Flash Lite", + "id": "openai/gpt-5.2", + "name": "GPT-5.2", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, + "supports_reasoning": true, + "supports_web_search": true, "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 8192, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "google/gemini-2.5-flash", - "name": "Gemini 2.5 Flash", + "id": "openai/gpt-5.3-codex", + "name": "GPT-5.3 Codex", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, + "supports_web_search": true, "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65535, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "google/gemini-2.5-flash-image", - "name": "Nano Banana", + "id": "openai/gpt-5.4", + "name": "GPT-5.4", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, - "supports_tool_calling": false, - "supports_reasoning": false, - "supports_web_search": false, - "supports_image_gen": true, + "supports_tool_calling": true, + "supports_reasoning": true, + "supports_web_search": true, "supports_pdf": true, - "context_window": 32768, - "max_output_tokens": 32768 + "context_window": 1050000, + "max_output_tokens": 128000, + "available_tools": [ + "web_search", + "function_calling" + ] }, { - "id": "google/gemini-2.5-flash-lite", - "name": "Gemini 2.5 Flash Lite", + "id": "openai/gpt-5.4-mini", + "name": "GPT-5.4 Mini", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, + "supports_web_search": true, "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65535, + "context_window": 400000, + "max_output_tokens": 128000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "google/gemini-2.5-pro", - "name": "Gemini 2.5 Pro", + "id": "openai/o3", + "name": "o3", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, + "supports_web_search": true, "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65536, + "context_window": 200000, + "max_output_tokens": 100000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "google/gemini-3-flash-preview", - "name": "Gemini 3 Flash", + "id": "openai/o4-mini", + "name": "o4-mini", "provider": "openrouter", - "api": "openai-completions", + "api": "openai-responses", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "supports_audio": true, - "supports_video": true, + "supports_web_search": true, "supports_pdf": true, - "context_window": 1048576, - "max_output_tokens": 65536, + "context_window": 200000, + "max_output_tokens": 100000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "google/gemini-3-pro-image-preview", - "name": "Nano Banana Pro", + "id": "qwen/qwen2.5-vl-32b-instruct", + "name": "Qwen 2.5 32B", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": false, - "supports_reasoning": true, + "supports_reasoning": false, "supports_web_search": false, - "supports_image_gen": true, - "supports_pdf": true, - "context_window": 65536, - "max_output_tokens": 32768 + "context_window": 128000 }, { - "id": "google/gemini-3.1-flash-lite-preview", - "name": "Gemini 3.1 Flash Lite", + "id": "qwen/qwen3-coder-next", + "name": "Qwen 3 Coder Next", "provider": "openrouter", "api": "openai-completions", - "supports_vision": true, + "supports_vision": false, "supports_tool_calling": true, - "supports_reasoning": true, + "supports_reasoning": false, "supports_web_search": false, - "supports_audio": true, - "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, + "context_window": 262144, "max_output_tokens": 65536, "available_tools": [ "function_calling" ] }, { - "id": "google/gemini-3.1-pro-preview", - "name": "Gemini 3.1 Pro", + "id": "qwen/qwen3.5-flash-02-23", + "name": "Qwen 3.5 Flash", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, "supports_reasoning": true, "supports_web_search": false, - "supports_audio": true, "supports_video": true, - "supports_pdf": true, - "context_window": 1048576, + "context_window": 1000000, "max_output_tokens": 65536, "available_tools": [ "function_calling" ] }, { - "id": "meta-llama/llama-3.3-70b-instruct", - "name": "Llama 3.3 70B", + "id": "qwen/qwen3.5-plus-02-15", + "name": "Qwen 3.5 Plus", "provider": "openrouter", "api": "openai-completions", - "supports_vision": false, + "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": false, + "supports_reasoning": true, "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 16384, + "supports_video": true, + "context_window": 1000000, + "max_output_tokens": 65536, "available_tools": [ "function_calling" ] }, { - "id": "meta-llama/llama-4-maverick", - "name": "Llama 4 Maverick", + "id": "x-ai/grok-4.1-fast", + "name": "Grok 4.1 Fast", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 1048576, - "max_output_tokens": 16384, + "supports_reasoning": true, + "supports_web_search": true, + "context_window": 2000000, + "max_output_tokens": 30000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "meta-llama/llama-4-scout", - "name": "Llama 4 Scout", + "id": "x-ai/grok-4.20-beta", + "name": "Grok 4.20 Beta", "provider": "openrouter", "api": "openai-completions", "supports_vision": true, "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 327680, - "max_output_tokens": 16384, + "supports_reasoning": true, + "supports_web_search": true, + "context_window": 2000000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "minimax/minimax-m2", - "name": "MiniMax M2", + "id": "x-ai/grok-code-fast-1", + "name": "Grok Code Fast 1", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, "supports_tool_calling": true, "supports_reasoning": true, - "supports_web_search": false, - "context_window": 196608, - "max_output_tokens": 196608, + "supports_web_search": true, + "context_window": 256000, + "max_output_tokens": 10000, "available_tools": [ + "web_search", "function_calling" ] }, { - "id": "minimax/minimax-m2.1", - "name": "MiniMax M2.1", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 196608, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "minimax/minimax-m2.5", - "name": "MiniMax M2.5", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 196608, - "max_output_tokens": 196608, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "moonshotai/kimi-k2", - "name": "Kimi K2 (0711)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 131000, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "moonshotai/kimi-k2-0905", - "name": "Kimi K2 (0905)", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 131072, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "moonshotai/kimi-k2.5", - "name": "Kimi K2.5", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 262144, - "max_output_tokens": 65535, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "openai/gpt-4.1", - "name": "GPT-4.1", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1047576, - "max_output_tokens": 32768, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-4.1-mini", - "name": "GPT-4.1 Mini", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1047576, - "max_output_tokens": 32768, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-4.1-nano", - "name": "GPT-4.1 Nano", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1047576, - "max_output_tokens": 32768, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-4o-mini", - "name": "GPT-4o-mini", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 128000, - "max_output_tokens": 16384, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5", - "name": "GPT-5", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5-image", - "name": "GPT ImageGen 1.5", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_image_gen": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5-image-mini", - "name": "GPT ImageGen", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_image_gen": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5-mini", - "name": "GPT-5 mini", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5-nano", - "name": "GPT-5 nano", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5.1", - "name": "GPT-5.1", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5.2", - "name": "GPT-5.2", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5.2-pro", - "name": "GPT-5.2 Pro", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 400000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5.3-chat", - "name": "GPT-5.3 Instant", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 128000, - "max_output_tokens": 16384, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-5.4", - "name": "GPT-5.4", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 1050000, - "max_output_tokens": 128000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/gpt-oss-120b", - "name": "GPT OSS 120B", - "provider": "openrouter", - "api": "responses", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 131072, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "openai/gpt-oss-20b", - "name": "GPT OSS 20B", - "provider": "openrouter", - "api": "responses", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 131072, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "openai/o3", - "name": "o3", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/o3-mini", - "name": "o3-mini", - "provider": "openrouter", - "api": "responses", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "openai/o3-pro", - "name": "o3 Pro", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "openai/o4-mini", - "name": "o4-mini", - "provider": "openrouter", - "api": "responses", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "supports_pdf": true, - "context_window": 200000, - "max_output_tokens": 100000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "qwen/qwen2.5-vl-32b-instruct", - "name": "Qwen 2.5 32B", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": false, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 128000 - }, - { - "id": "qwen/qwen3-235b-a22b", - "name": "Qwen 3 235B", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 8192, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "qwen/qwen3-32b", - "name": "Qwen 3 32B", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 40960, - "max_output_tokens": 40960, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "qwen/qwen3-coder", - "name": "Qwen 3 Coder", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": false, - "context_window": 262144, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "x-ai/grok-3", - "name": "Grok 3", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": false, - "supports_web_search": true, - "context_window": 131072, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "x-ai/grok-3-mini", - "name": "Grok 3 Mini", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "context_window": 131072, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "x-ai/grok-4", - "name": "Grok 4", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "context_window": 256000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "x-ai/grok-4-fast", - "name": "Grok 4 Fast", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "context_window": 2000000, - "max_output_tokens": 30000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "x-ai/grok-4.1-fast", - "name": "Grok 4.1 Fast", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": true, - "context_window": 2000000, - "max_output_tokens": 30000, - "available_tools": [ - "web_search", - "function_calling" - ] - }, - { - "id": "z-ai/glm-4.5", - "name": "GLM 4.5", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 98304, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "z-ai/glm-4.5-air", - "name": "GLM 4.5 Air", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 131072, - "max_output_tokens": 98304, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "z-ai/glm-4.5v", - "name": "GLM 4.5V", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 65536, - "max_output_tokens": 16384, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "z-ai/glm-4.6", - "name": "GLM 4.6", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 204800, - "max_output_tokens": 204800, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "z-ai/glm-4.6v", - "name": "GLM 4.6V", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": true, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "supports_video": true, - "context_window": 131072, - "max_output_tokens": 131072, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "z-ai/glm-4.7", - "name": "GLM 4.7", - "provider": "openrouter", - "api": "openai-completions", - "supports_vision": false, - "supports_tool_calling": true, - "supports_reasoning": true, - "supports_web_search": false, - "context_window": 202752, - "available_tools": [ - "function_calling" - ] - }, - { - "id": "z-ai/glm-5", - "name": "GLM 5", + "id": "z-ai/glm-5-turbo", + "name": "GLM 5 Turbo", "provider": "openrouter", "api": "openai-completions", "supports_vision": false, @@ -1091,6 +453,7 @@ "supports_reasoning": true, "supports_web_search": false, "context_window": 202752, + "max_output_tokens": 131072, "available_tools": [ "function_calling" ] @@ -1098,8 +461,8 @@ ], "aliases": { "beeper/default": "anthropic/claude-opus-4.6", - "beeper/fast": "openai/gpt-5-mini", - "beeper/reasoning": "openai/gpt-5.2", - "beeper/smart": "openai/gpt-5.2" + "beeper/fast": "openai/gpt-5.4-mini", + "beeper/reasoning": "openai/o3", + "beeper/smart": "openai/gpt-5.4" } } diff --git a/pkg/connector/integrations_example-config.yaml b/pkg/connector/integrations_example-config.yaml deleted file mode 100644 index be28341e..00000000 --- a/pkg/connector/integrations_example-config.yaml +++ /dev/null @@ -1,364 +0,0 @@ -# Connector-specific configuration lives under the `network:` section of the -# main config file. - -# Beeper Cloud credentials for automatic login (optional). -# If user_mxid, base_url, and token are set, users don't need to manually log in. -beeper: - user_mxid: "" # Owning Matrix user for the built-in Beeper Cloud login. - base_url: "" # Optional. If empty, login uses selected Beeper domain. - token: "" # Beeper Matrix access token - -# Per-provider default models and settings. -# These are used when a room doesn't have a specific model configured. -providers: - beeper: - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - openai: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://api.openai.com/v1 - base_url: "https://api.openai.com/v1" - default_model: "openai/gpt-5.2" - openrouter: - # Optional. If set, overrides login-provided key. - api_key: "" - # Optional. Defaults to https://openrouter.ai/api/v1 - base_url: "https://openrouter.ai/api/v1" - default_model: "anthropic/claude-opus-4.6" - # PDF processing engine for OpenRouter's file-parser plugin. - # Options: pdf-text (free), mistral-ocr (OCR, paid, default), native - default_pdf_engine: "mistral-ocr" - -# Optional model catalog seeding. -# models: -# mode: "merge" # merge | replace -# providers: -# openai: -# models: -# - id: "gpt-5.2" -# name: "GPT-5.2" -# reasoning: true -# input: ["text", "image"] -# context_window: 128000 -# max_tokens: 8192 - -# Global settings -default_system_prompt: | - You are a helpful, concise assistant. - Ask clarifying questions when needed. - Follow the user's intent and be accurate. -model_cache_duration: 6h - -# Optional message rendering settings. -messages: - # History defaults for prompt construction. - # Set 0 to disable. - direct_chat: - history_limit: 20 - group_chat: - history_limit: 50 - # Queue behavior while the agent is busy. - queue: - # Modes: collect, followup, steer, steer-backlog, interrupt - mode: "collect" - # Debounce time before draining queued messages (ms). - debounce_ms: 1000 - # Maximum queued messages before drop policy applies. - cap: 20 - # Drop policy when cap is exceeded: summarize, old, new - drop: "summarize" - -# Command authorization settings. -commands: - # Optional allowlist for owner-only tools/commands (Matrix IDs, or "matrix:@user:server"). - owner_allow_from: [] - -# Tool approval gating. -tool_approvals: - enabled: true - ttl_seconds: 600 - require_for_mcp: true - # List of builtin tool names that require approval (subject to per-tool action allowlists). - # Note: `message` approvals apply to Desktop API routing too (e.g. action=send/reply/edit with desktop chat hints), - # while Desktop read-only actions like desktop-search-* do not require approval. - require_for_tools: ["message", "cron", "gravatar_set", "create_agent", "fork_agent", "edit_agent", "delete_agent", "modify_room", "sessions_send", "sessions_spawn", "run_internal_command"] - # Fallback when approval times out: "deny" (default) | "allow". - # Set to "allow" for cron/automated contexts where no human can respond. - -# Optional per-channel overrides. -channels: - matrix: - # Matrix reply/thread behavior. - reply_to_mode: "first" - -# Session configuration. -session: - # Scope for session state: per-sender (default) or global. - scope: "per-sender" - # Main session key alias (default: "main"). - main_key: "main" - -# External tool providers (search + fetch). Proxy is optional. -tools: - search: - provider: "openrouter" - fallbacks: ["exa", "brave", "perplexity"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - type: "auto" - num_results: 5 - include_text: false - text_max_chars: 500 - highlights: true # enabled by default; provides description snippets for source cards - brave: - api_key: "" - base_url: "https://api.search.brave.com/res/v1/web/search" - perplexity: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "perplexity/sonar-pro" - openrouter: - api_key: "" - base_url: "https://openrouter.ai/api/v1" - model: "openai/gpt-5.2" - fetch: - provider: "exa" - fallbacks: ["direct"] - exa: - api_key: "" - base_url: "https://api.exa.ai" - include_text: true - text_max_chars: 5000 - direct: - enabled: true - timeout_seconds: 30 - max_chars: 50000 - max_redirects: 3 - - # Generic MCP behavior. - mcp: - # Disabled by default for safety. Enable explicitly to allow local stdio MCP servers. - enable_stdio: false - - # Virtual filesystem tools. - vfs: - apply_patch: - enabled: false - allow_models: [] - - # Media understanding/transcription. - # Supports provider/CLI entries and per-capability defaults. - media: - concurrency: 2 - image: - enabled: true - prompt: "Describe the image." - max_bytes: 10485760 - max_chars: 500 - timeout_seconds: 60 - models: - - provider: "openrouter" - model: "google/gemini-3-flash-preview" - audio: - enabled: true - prompt: "Transcribe the audio." - language: "" - # CLI transcription auto-detection (whisper/whisper.cpp) is not implemented yet. - max_bytes: 20971520 - timeout_seconds: 60 - models: - - provider: "openai" - model: "gpt-4o-mini-transcribe" - video: - enabled: true - prompt: "Describe the video." - max_bytes: 52428800 - timeout_seconds: 120 - models: - - provider: "openrouter" - model: "google/gemini-3-flash-preview" - chunking: - tokens: 400 - overlap: 80 - sync: - on_session_start: true - on_search: true - watch: true - watch_debounce_ms: 1500 - interval_minutes: 0 - sessions: - delta_bytes: 100000 - delta_messages: 50 - query: - max_results: 6 - min_score: 0.35 - hybrid: - candidate_multiplier: 4 - cache: - enabled: true - max_entries: -1 # Unlimited when cache.enabled is true; experimental.session_memory does not change cache size semantics. - experimental: - session_memory: false - -# Recall configuration. -# recall: -# citations: "auto" # auto | on | off -# inject_context: false # default false. when true, injects MEMORY.md snippets as extra system context. - - # Tool policy. Controls allow/deny lists and profiles. - # tool_policy: - # profile: "full" - # # group:openclaw is the strict OpenClaw native tool set. - # # group:agentremote includes agentremote-only extras (beeper_docs, gravatar_*, tts, image_generate, calculator, etc). - # allow: ["group:openclaw", "group:agentremote"] - # deny: [] - # subagents: - # tools: - # deny: ["sessions_list", "sessions_history", "sessions_send"] - - # Agent defaults. - # agents: - # defaults: - # subagents: - # model: "anthropic/claude-sonnet-4.5" - # allow_agents: ["*"] - # skip_bootstrap: false - # bootstrap_max_chars: 20000 - # typing_mode: "instant" # never|instant|thinking|message (message ignores NO_REPLY) - # typing_interval_seconds: 6 # refresh cadence, not start time (heartbeats never show typing) - # soul_evil: - # file: "SOUL_EVIL.md" - # chance: 0.1 - # purge: - # at: "21:00" - # duration: "15m" - -# Context pruning configuration. -# Reduces token usage by intelligently truncating old tool results. -pruning: - # Pruning mode: off | cache-ttl - # cache-ttl is the default pruning mode. - mode: "cache-ttl" - - # Refresh interval for cache-ttl mode. - ttl: "1h" - - # Enable proactive context pruning - enabled: true - - # Ratio of context window usage that triggers soft trimming (0.0-1.0) - # At 30% usage, large tool results start getting truncated - soft_trim_ratio: 0.3 - - # Ratio of context window usage that triggers hard clearing (0.0-1.0) - # At 50% usage, old tool results are replaced with placeholder - hard_clear_ratio: 0.5 - - # Number of recent assistant messages to protect from pruning - keep_last_assistants: 3 - - # Minimum total chars in prunable tool results before hard clear kicks in - min_prunable_chars: 50000 - - # Tool results larger than this are candidates for soft trimming - soft_trim_max_chars: 4000 - - # When soft trimming, keep this many chars from the start - soft_trim_head_chars: 1500 - - # When soft trimming, keep this many chars from the end - soft_trim_tail_chars: 1500 - - # Enable/disable hard clear phase - hard_clear_enabled: true - - # Placeholder text for hard-cleared tool results - hard_clear_placeholder: "[Old tool result content cleared]" - - # Tool patterns to allow/deny pruning (supports wildcards: list_*, *_search) - # Empty means all tools are prunable unless denied - # tools_allow: [] - # tools_deny: [] - - # --- LLM-based summarization (compaction) --- - # When enabled, uses an LLM to generate intelligent summaries of compacted - # content instead of just using placeholder text. This preserves context better. - - # Enable LLM summarization (default: true when pruning is enabled) - summarization_enabled: true - - # Model to use for generating summaries (default: fast model) - summarization_model: "openai/gpt-5.2" - - # Maximum tokens for generated summaries - max_summary_tokens: 500 - - # Compaction mode: - # - default: balanced reduction - # - safeguard: preserves recent context more aggressively - compaction_mode: "safeguard" - - # Minimum recent token budget preserved during safeguard compaction - keep_recent_tokens: 20000 - - # Maximum ratio of context that history can consume (0.0-1.0) - # When exceeded, oldest messages are summarized to fit budget - max_history_share: 0.5 - - # Token budget reserved for compaction output - reserve_tokens: 20000 - # Floor applied to reserve_tokens to avoid aggressive overfill - reserve_tokens_floor: 20000 - - # Optional post-compaction system context injected before retry - post_compaction_refresh_prompt: "[Post-compaction context refresh]\nRe-anchor to the latest user intent and preserve unresolved tasks and identifiers." - - # Additional instructions for the summarization model - # custom_instructions: "Focus on preserving code decisions and TODOs" - - # Identifier preservation policy for summaries: - # - strict (default): preserve opaque identifiers exactly - # - off: no special identifier-preservation instruction - # - custom: use identifier_instructions below - identifier_policy: "strict" - # identifier_instructions: "Keep ticket IDs, hashes, and hostnames unchanged." - - # Optional pre-compaction overflow flush turn. - # Enabled by default. Disable explicitly if you want no pre-flush. - overflow_flush: - enabled: true - soft_threshold_tokens: 4000 - prompt: "Pre-compaction overflow flush. Persist any durable notes now if your tools support it. If nothing to store, reply with NO_REPLY." - system_prompt: "Pre-compaction overflow flush turn. The session is near auto-compaction; persist durable notes if possible. You may reply, but usually NO_REPLY is correct." - -# Link preview configuration. -# Automatically fetches metadata for URLs in messages to provide context to the AI -# and generate rich previews in outgoing AI responses. -link_previews: - # Enable link preview functionality (default: true) - enabled: true - - # Maximum number of URLs to fetch from user messages for AI context (default: 3) - max_urls_inbound: 3 - - # Maximum number of URLs to preview in AI responses (default: 5) - max_urls_outbound: 5 - - # Timeout for fetching each URL (default: 10s) - fetch_timeout: 10s - - # Maximum characters from description to include in context (default: 500) - max_content_chars: 500 - - # Maximum page size to download in bytes (default: 10MB) - max_page_bytes: 10485760 - - # Maximum image size to download in bytes (default: 5MB) - max_image_bytes: 5242880 - - # How long to cache URL previews (default: 1h) - cache_ttl: 1h diff --git a/pkg/integrations/cron/integration.go b/pkg/integrations/cron/integration.go index 74081b73..8cf3ae16 100644 --- a/pkg/integrations/cron/integration.go +++ b/pkg/integrations/cron/integration.go @@ -138,7 +138,7 @@ func (i *Integration) executeCronCommand(ctx context.Context, call iruntime.Comm reply("Cron add failed: %s", err.Error()) return nil } - deps := i.buildToolExecDeps(ctx, commandScopeToToolScope(call.Scope)) + deps := i.buildToolExecDeps(ctx, iruntime.ToolScope(call.Scope)) injectToolContext(&input, deps.ResolveCreateContext) if input.Delivery != nil && strings.EqualFold(strings.TrimSpace(string(input.Delivery.Mode)), "announce") && deps.ValidateDeliveryTo != nil { if err := deps.ValidateDeliveryTo(input.Delivery.To); err != nil { @@ -169,7 +169,7 @@ func (i *Integration) executeCronCommand(ctx context.Context, call iruntime.Comm reply("Cron update failed: %s", err.Error()) return nil } - deps := i.buildToolExecDeps(ctx, commandScopeToToolScope(call.Scope)) + deps := i.buildToolExecDeps(ctx, iruntime.ToolScope(call.Scope)) if patch.Delivery != nil && patch.Delivery.To != nil && deps.ValidateDeliveryTo != nil { if err := deps.ValidateDeliveryTo(*patch.Delivery.To); err != nil { reply("Cron update failed: %s", err.Error()) @@ -228,7 +228,7 @@ func (i *Integration) buildToolExecDeps(ctx context.Context, scope iruntime.Tool ResolveCreateContext: func() ToolCreateContext { agentID := i.host.DefaultAgentID() if scope.Meta != nil { - if resolved := strings.TrimSpace(i.host.AgentIDFromMeta(scope.Meta)); resolved != "" { + if resolved := strings.TrimSpace(scope.Meta.AgentID()); resolved != "" { agentID = resolved } } @@ -238,7 +238,7 @@ func (i *Integration) buildToolExecDeps(ctx context.Context, scope iruntime.Tool } sourceInternal := false if scope.Meta != nil { - sourceInternal = i.host.IsInternalRoom(scope.Meta) + sourceInternal = scope.Meta.InternalRoom() } return ToolCreateContext{AgentID: agentID, SourceInternal: sourceInternal, SourceRoomID: roomID} }, @@ -279,14 +279,6 @@ func (i *Integration) buildToolExecDeps(ctx context.Context, scope iruntime.Tool return deps } -func commandScopeToToolScope(scope iruntime.CommandScope) iruntime.ToolScope { - return iruntime.ToolScope{ - Client: scope.Client, - Portal: scope.Portal, - Meta: scope.Meta, - } -} - var ( _ iruntime.ToolIntegration = (*Integration)(nil) _ iruntime.CommandIntegration = (*Integration)(nil) diff --git a/pkg/integrations/memory/integration.go b/pkg/integrations/memory/integration.go index 7ba5fb35..31281db3 100644 --- a/pkg/integrations/memory/integration.go +++ b/pkg/integrations/memory/integration.go @@ -9,6 +9,7 @@ import ( "github.com/openai/openai-go/v3" "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/bridgev2" "github.com/beeper/agentremote/pkg/agents" iruntime "github.com/beeper/agentremote/pkg/integrations/runtime" @@ -26,8 +27,8 @@ type ProviderStatus = memorycore.ProviderStatus type ResolvedConfig = memorycore.ResolvedConfig // Integration is the self-owned memory integration module. -// It implements ToolIntegration, PromptIntegration, CommandIntegration, -// EventIntegration, LoginPurgeIntegration, and LoginLifecycleIntegration +// It implements ToolIntegration, CommandIntegration, EventIntegration, +// LoginPurgeIntegration, and LoginLifecycleIntegration // directly, wiring all deps from Host // capability interfaces. type Integration struct { @@ -69,7 +70,7 @@ func (i *Integration) ToolAvailability(_ context.Context, scope iruntime.ToolSco return false, false, iruntime.SourceGlobalDefault, "" } if scope.Meta != nil { - agentID := i.host.AgentIDFromMeta(scope.Meta) + agentID := i.agentIDFromEventMeta(scope.Meta) _, errMsg := i.getManager(agentID) if errMsg != "" { return true, false, iruntime.SourceProviderLimit, errMsg @@ -78,12 +79,8 @@ func (i *Integration) ToolAvailability(_ context.Context, scope iruntime.ToolSco return true, true, iruntime.SourceGlobalDefault, "" } -func (i *Integration) AdditionalSystemMessages(_ context.Context, _ iruntime.PromptScope) []openai.ChatCompletionMessageParamUnion { - return nil -} - -func (i *Integration) AugmentPrompt(ctx context.Context, scope iruntime.PromptScope, prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { - return AugmentPrompt(ctx, scope, prompt, PromptAugmentDeps{ +func (i *Integration) PromptContextText(ctx context.Context, scope iruntime.PromptScope) string { + return BuildPromptContextText(ctx, scope.Portal, scope.Meta, PromptContextDeps{ ShouldInjectContext: i.shouldInjectMemoryPromptContext, ShouldBootstrap: i.shouldBootstrapMemoryPromptContext, ResolveBootstrapPaths: i.resolveMemoryBootstrapPaths, @@ -138,16 +135,16 @@ func (i *Integration) OnCompactionLifecycle(ctx context.Context, evt iruntime.Co } switch evt.Phase { case iruntime.CompactionLifecycleStart: - i.host.SetModuleMeta(evt.Meta, "compaction_in_flight", true) + evt.Meta.SetModuleMetaValue("compaction_in_flight", true) case iruntime.CompactionLifecycleEnd: - i.host.SetModuleMeta(evt.Meta, "compaction_in_flight", false) - i.host.SetModuleMeta(evt.Meta, "last_compaction_at", time.Now().UnixMilli()) - i.host.SetModuleMeta(evt.Meta, "last_compaction_dropped_count", evt.DroppedCount) + evt.Meta.SetModuleMetaValue("compaction_in_flight", false) + evt.Meta.SetModuleMetaValue("last_compaction_at", time.Now().UnixMilli()) + evt.Meta.SetModuleMetaValue("last_compaction_dropped_count", evt.DroppedCount) case iruntime.CompactionLifecycleFail: - i.host.SetModuleMeta(evt.Meta, "compaction_in_flight", false) - i.host.SetModuleMeta(evt.Meta, "last_compaction_error", strings.TrimSpace(evt.Error)) + evt.Meta.SetModuleMetaValue("compaction_in_flight", false) + evt.Meta.SetModuleMetaValue("last_compaction_error", strings.TrimSpace(evt.Error)) case iruntime.CompactionLifecycleRefresh: - i.host.SetModuleMeta(evt.Meta, "last_compaction_refresh_at", time.Now().UnixMilli()) + evt.Meta.SetModuleMetaValue("last_compaction_refresh_at", time.Now().UnixMilli()) } if evt.Portal == nil { return @@ -204,11 +201,6 @@ func (i *Integration) buildCommandExecDeps() CommandExecDeps { } } -func asOverflowCall(call any) (iruntime.ContextOverflowCall, bool) { - oc, ok := call.(iruntime.ContextOverflowCall) - return oc, ok -} - func toInt64(v any) int64 { switch n := v.(type) { case int64: @@ -228,50 +220,39 @@ func (i *Integration) buildOverflowDeps() OverflowDeps { TrimPrompt: func(prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion { return i.host.SmartTruncatePrompt(prompt, 0.5) }, - ContextWindow: func(call any) int { - oc, ok := asOverflowCall(call) - if !ok { - return 128000 - } - return i.host.ContextWindow(oc.Meta) + ContextWindow: func(call iruntime.ContextOverflowCall) int { + return i.host.ContextWindow(call.Meta) }, ReserveTokens: func() int { return i.host.CompactorReserveTokens() }, - EffectiveModel: func(call any) string { - oc, ok := asOverflowCall(call) - if !ok { - return "" - } - return i.host.EffectiveModel(oc.Meta) + EffectiveModel: func(call iruntime.ContextOverflowCall) string { + return i.host.EffectiveModel(call.Meta) }, EstimateTokens: func(prompt []openai.ChatCompletionMessageParamUnion, model string) int { return i.host.EstimateTokens(prompt, model) }, - AlreadyFlushed: func(call any) bool { - oc, ok := asOverflowCall(call) - if !ok { + AlreadyFlushed: func(call iruntime.ContextOverflowCall) bool { + if call.Meta == nil { return false } - flushAtMs := toInt64(i.host.GetModuleMeta(oc.Meta, "overflow_flush_at")) + flushAtMs := toInt64(call.Meta.ModuleMetaValue("overflow_flush_at")) if flushAtMs == 0 { return false } - flushCC := toInt64(i.host.GetModuleMeta(oc.Meta, "overflow_flush_compaction_count")) - return int(flushCC) == i.host.CompactionCount(oc.Meta) + flushCC := toInt64(call.Meta.ModuleMetaValue("overflow_flush_compaction_count")) + return int(flushCC) == call.Meta.CompactionCounter() }, - MarkFlushed: func(ctx context.Context, call any) { - oc, _ := asOverflowCall(call) - if oc.Portal == nil || oc.Meta == nil { + MarkFlushed: func(ctx context.Context, call iruntime.ContextOverflowCall) { + if call.Portal == nil || call.Meta == nil { return } - i.host.SetModuleMeta(oc.Meta, "overflow_flush_at", time.Now().UnixMilli()) - i.host.SetModuleMeta(oc.Meta, "overflow_flush_compaction_count", i.host.CompactionCount(oc.Meta)) - _ = i.host.SavePortal(ctx, oc.Portal, "overflow flush") + call.Meta.SetModuleMetaValue("overflow_flush_at", time.Now().UnixMilli()) + call.Meta.SetModuleMetaValue("overflow_flush_compaction_count", call.Meta.CompactionCounter()) + _ = i.host.SavePortal(ctx, call.Portal, "overflow flush") }, - RunFlushToolLoop: func(ctx context.Context, call any, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) { - oc, _ := asOverflowCall(call) - return i.runFlushToolLoop(ctx, oc.Portal, oc.Meta, model, prompt) + RunFlushToolLoop: func(ctx context.Context, call iruntime.ContextOverflowCall, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) { + return i.runFlushToolLoop(ctx, call.Portal, call.Meta, model, prompt) }, OnError: func(_ context.Context, err error) { i.host.Logger().Warn("overflow flush failed", map[string]any{"error": err.Error()}) @@ -279,7 +260,7 @@ func (i *Integration) buildOverflowDeps() OverflowDeps { } } -func (i *Integration) shouldInjectMemoryPromptContext(scope iruntime.PromptScope) bool { +func (i *Integration) shouldInjectMemoryPromptContext(_ *bridgev2.Portal, _ iruntime.Meta) bool { if cfg := i.host.ModuleConfig(moduleName); cfg != nil { inject, _ := cfg["inject_context"].(bool) return inject @@ -287,15 +268,18 @@ func (i *Integration) shouldInjectMemoryPromptContext(scope iruntime.PromptScope return false } -func (i *Integration) shouldBootstrapMemoryPromptContext(scope iruntime.PromptScope) bool { - raw := i.host.GetModuleMeta(scope.Meta, "memory_bootstrap_at") +func (i *Integration) shouldBootstrapMemoryPromptContext(_ *bridgev2.Portal, meta iruntime.Meta) bool { + if meta == nil { + return false + } + raw := meta.ModuleMetaValue("memory_bootstrap_at") if raw == nil { return true } return toInt64(raw) == 0 } -func (i *Integration) resolveMemoryBootstrapPaths(_ iruntime.PromptScope) []string { +func (i *Integration) resolveMemoryBootstrapPaths(_ *bridgev2.Portal, _ iruntime.Meta) []string { _, loc := i.host.UserTimezone() if loc == nil { loc = time.UTC @@ -309,19 +293,16 @@ func (i *Integration) resolveMemoryBootstrapPaths(_ iruntime.PromptScope) []stri } } -func (i *Integration) markMemoryPromptBootstrapped(ctx context.Context, scope iruntime.PromptScope) { - if scope.Portal == nil || scope.Meta == nil { +func (i *Integration) markMemoryPromptBootstrapped(ctx context.Context, portal *bridgev2.Portal, meta iruntime.Meta) { + if portal == nil || meta == nil { return } - i.host.SetModuleMeta(scope.Meta, "memory_bootstrap_at", time.Now().UnixMilli()) - _ = i.host.SavePortal(ctx, scope.Portal, "memory bootstrap") + meta.SetModuleMetaValue("memory_bootstrap_at", time.Now().UnixMilli()) + _ = i.host.SavePortal(ctx, portal, "memory bootstrap") } -func (i *Integration) readMemoryPromptSection(ctx context.Context, scope iruntime.PromptScope, path string) string { - agentID := "" - if scope.Meta != nil { - agentID = i.host.AgentIDFromMeta(scope.Meta) - } +func (i *Integration) readMemoryPromptSection(ctx context.Context, meta iruntime.Meta, path string) string { + agentID := i.agentIDFromEventMeta(meta) content, filePath, found, err := i.host.ReadTextFile(ctx, agentID, path) if err != nil || !found { return "" @@ -358,8 +339,8 @@ func (i *Integration) getManager(agentID string) (*MemorySearchManager, string) func (i *Integration) runFlushToolLoop( ctx context.Context, - portal any, - meta any, + portal *bridgev2.Portal, + meta iruntime.Meta, model string, messages []openai.ChatCompletionMessageParamUnion, ) (bool, error) { @@ -460,28 +441,20 @@ func (i *Integration) writeMemoryCommandFile( content string, maxBytes int, ) (string, error) { - agentID := "" - if scope.Meta != nil { - agentID = i.host.AgentIDFromMeta(scope.Meta) - } + agentID := i.agentIDFromEventMeta(scope.Meta) return i.host.WriteTextFile(ctx, scope.Portal, scope.Meta, agentID, mode, path, content, maxBytes) } -func (i *Integration) agentIDFromEventMeta(meta any) string { +func (i *Integration) agentIDFromEventMeta(meta iruntime.Meta) string { var rawAgentID string if meta != nil { - rawAgentID = i.host.AgentIDFromMeta(meta) + rawAgentID = meta.AgentID() } return i.host.ResolveAgentID(rawAgentID, i.host.DefaultAgentID()) } func (i *Integration) resolveBridgeDB() *dbutil.Database { - raw := i.host.BridgeDB() - if raw == nil { - return nil - } - db, _ := raw.(*dbutil.Database) - return db + return i.host.BridgeDB() } // splitQuotedArgs parses a raw argument string into tokens, respecting quoted segments. diff --git a/pkg/integrations/memory/manager.go b/pkg/integrations/memory/manager.go index 8f96e561..36705077 100644 --- a/pkg/integrations/memory/manager.go +++ b/pkg/integrations/memory/manager.go @@ -115,11 +115,7 @@ func GetMemorySearchManager(host iruntime.Host, agentID string) (*MemorySearchMa if host == nil { return nil, "memory search unavailable" } - rawDB := host.BridgeDB() - if rawDB == nil { - return nil, "memory search unavailable" - } - db, _ := rawDB.(*dbutil.Database) + db := host.BridgeDB() if db == nil { return nil, "memory search unavailable" } diff --git a/pkg/integrations/memory/module_exec.go b/pkg/integrations/memory/module_exec.go index 5eb2eb0f..9e7142e3 100644 --- a/pkg/integrations/memory/module_exec.go +++ b/pkg/integrations/memory/module_exec.go @@ -196,7 +196,6 @@ func ExecuteCommand(ctx context.Context, call iruntime.CommandCall, deps Command } action := strings.ToLower(strings.TrimSpace(call.Args[0])) scope := iruntime.ToolScope{ - Client: call.Scope.Client, Portal: call.Scope.Portal, Meta: call.Scope.Meta, } diff --git a/pkg/integrations/memory/overflow_exec.go b/pkg/integrations/memory/overflow_exec.go index dec9ccfb..5db54a74 100644 --- a/pkg/integrations/memory/overflow_exec.go +++ b/pkg/integrations/memory/overflow_exec.go @@ -22,19 +22,19 @@ type FlushSettings struct { type OverflowDeps struct { ResolveSettings func() *FlushSettings TrimPrompt func(prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion - ContextWindow func(call any) int + ContextWindow func(call iruntime.ContextOverflowCall) int ReserveTokens func() int - EffectiveModel func(call any) string + EffectiveModel func(call iruntime.ContextOverflowCall) string EstimateTokens func(prompt []openai.ChatCompletionMessageParamUnion, model string) int - AlreadyFlushed func(call any) bool - MarkFlushed func(ctx context.Context, call any) - RunFlushToolLoop func(ctx context.Context, call any, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) + AlreadyFlushed func(call iruntime.ContextOverflowCall) bool + MarkFlushed func(ctx context.Context, call iruntime.ContextOverflowCall) + RunFlushToolLoop func(ctx context.Context, call iruntime.ContextOverflowCall, model string, prompt []openai.ChatCompletionMessageParamUnion) (bool, error) OnError func(ctx context.Context, err error) } func HandleOverflow( ctx context.Context, - call any, + call iruntime.ContextOverflowCall, prompt []openai.ChatCompletionMessageParamUnion, deps OverflowDeps, ) { @@ -64,8 +64,8 @@ func HandleOverflow( model = deps.EffectiveModel(call) } totalTokens := 0 - if overflowCall, ok := call.(iruntime.ContextOverflowCall); ok && overflowCall.RequestedTokens > 0 { - totalTokens = overflowCall.RequestedTokens + if call.RequestedTokens > 0 { + totalTokens = call.RequestedTokens } if totalTokens <= 0 && deps.EstimateTokens != nil { totalTokens = deps.EstimateTokens(prompt, model) @@ -102,8 +102,8 @@ func HandleOverflow( func shouldRunFlush( totalTokens, contextWindow, reserveTokens int, settings *FlushSettings, - alreadyFlushed func(call any) bool, - call any, + alreadyFlushed func(call iruntime.ContextOverflowCall) bool, + call iruntime.ContextOverflowCall, ) bool { if settings == nil { return false diff --git a/pkg/integrations/memory/prompt_exec.go b/pkg/integrations/memory/prompt_exec.go index 5a480df9..bda520f0 100644 --- a/pkg/integrations/memory/prompt_exec.go +++ b/pkg/integrations/memory/prompt_exec.go @@ -2,60 +2,56 @@ package memory import ( "context" - "slices" "strings" - "github.com/openai/openai-go/v3" + "maunium.net/go/mautrix/bridgev2" iruntime "github.com/beeper/agentremote/pkg/integrations/runtime" ) -type PromptAugmentDeps struct { - ShouldInjectContext func(scope iruntime.PromptScope) bool - ShouldBootstrap func(scope iruntime.PromptScope) bool - ResolveBootstrapPaths func(scope iruntime.PromptScope) []string - MarkBootstrapped func(ctx context.Context, scope iruntime.PromptScope) - ReadSection func(ctx context.Context, scope iruntime.PromptScope, path string) string +type PromptContextDeps struct { + ShouldInjectContext func(portal *bridgev2.Portal, meta iruntime.Meta) bool + ShouldBootstrap func(portal *bridgev2.Portal, meta iruntime.Meta) bool + ResolveBootstrapPaths func(portal *bridgev2.Portal, meta iruntime.Meta) []string + MarkBootstrapped func(ctx context.Context, portal *bridgev2.Portal, meta iruntime.Meta) + ReadSection func(ctx context.Context, meta iruntime.Meta, path string) string } -func AugmentPrompt( +func BuildPromptContextText( ctx context.Context, - scope iruntime.PromptScope, - prompt []openai.ChatCompletionMessageParamUnion, - deps PromptAugmentDeps, -) []openai.ChatCompletionMessageParamUnion { - if deps.ShouldInjectContext == nil || !deps.ShouldInjectContext(scope) { - return prompt + portal *bridgev2.Portal, + meta iruntime.Meta, + deps PromptContextDeps, +) string { + if deps.ShouldInjectContext == nil || !deps.ShouldInjectContext(portal, meta) { + return "" } if deps.ReadSection == nil { - return prompt + return "" } sections := make([]string, 0, 3) - if section := deps.ReadSection(ctx, scope, "MEMORY.md"); section != "" { + if section := deps.ReadSection(ctx, meta, "MEMORY.md"); section != "" { sections = append(sections, section) - } else if section := deps.ReadSection(ctx, scope, "memory.md"); section != "" { + } else if section := deps.ReadSection(ctx, meta, "memory.md"); section != "" { sections = append(sections, section) } - if deps.ShouldBootstrap != nil && deps.ShouldBootstrap(scope) { + if deps.ShouldBootstrap != nil && deps.ShouldBootstrap(portal, meta) { if deps.ResolveBootstrapPaths != nil { - for _, path := range deps.ResolveBootstrapPaths(scope) { - if section := deps.ReadSection(ctx, scope, path); section != "" { + for _, path := range deps.ResolveBootstrapPaths(portal, meta) { + if section := deps.ReadSection(ctx, meta, path); section != "" { sections = append(sections, section) } } } if deps.MarkBootstrapped != nil { - deps.MarkBootstrapped(ctx, scope) + deps.MarkBootstrapped(ctx, portal, meta) } } if len(sections) == 0 { - return prompt + return "" } - contextText := strings.Join(sections, "\n\n") - out := slices.Clone(prompt) - out = append(out, openai.SystemMessage(contextText)) - return out + return strings.Join(sections, "\n\n") } diff --git a/pkg/integrations/memory/sessions.go b/pkg/integrations/memory/sessions.go index 229dd206..02b340cd 100644 --- a/pkg/integrations/memory/sessions.go +++ b/pkg/integrations/memory/sessions.go @@ -39,11 +39,7 @@ func (m *MemorySearchManager) activeSessionPortals(ctx context.Context) (map[str if key == "" { continue } - portalKey, ok := info.PortalKey.(networkid.PortalKey) - if !ok { - continue - } - active[key] = sessionPortal{key: key, portalKey: portalKey} + active[key] = sessionPortal{key: key, portalKey: info.PortalKey} } return active, nil } diff --git a/pkg/integrations/runtime/helpers.go b/pkg/integrations/runtime/helpers.go index bcf5f733..d63a138a 100644 --- a/pkg/integrations/runtime/helpers.go +++ b/pkg/integrations/runtime/helpers.go @@ -7,15 +7,12 @@ import ( ) // ZerologFromHost extracts a zerolog.Logger from a Host. -// Returns zerolog.Nop() if the underlying logger is not a zerolog.Logger. +// Returns zerolog.Nop() if the host is nil. func ZerologFromHost(host Host) zerolog.Logger { if host == nil { return zerolog.Nop() } - if zl, ok := host.RawLogger().(zerolog.Logger); ok { - return zl - } - return zerolog.Nop() + return host.RawLogger() } // ModuleOrNil returns nil when the host is absent, otherwise it constructs the module. diff --git a/pkg/integrations/runtime/host_types.go b/pkg/integrations/runtime/host_types.go index 29736c4c..aa450783 100644 --- a/pkg/integrations/runtime/host_types.go +++ b/pkg/integrations/runtime/host_types.go @@ -1,6 +1,19 @@ package runtime -import "github.com/openai/openai-go/v3" +import ( + "maunium.net/go/mautrix/bridgev2/networkid" + + "github.com/openai/openai-go/v3" +) + +// Meta describes the portal metadata behavior integration modules depend on. +type Meta interface { + ModuleMetaValue(key string) any + SetModuleMetaValue(key string, value any) + AgentID() string + CompactionCounter() int + InternalRoom() bool +} // MessageSummary is a generic message summary. type MessageSummary struct { @@ -33,5 +46,5 @@ type CompletionResult struct { // SessionPortalInfo is a generic portal reference for session listing. type SessionPortalInfo struct { Key string - PortalKey any + PortalKey networkid.PortalKey } diff --git a/pkg/integrations/runtime/interfaces.go b/pkg/integrations/runtime/interfaces.go index 4c5672b6..88b90570 100644 --- a/pkg/integrations/runtime/interfaces.go +++ b/pkg/integrations/runtime/interfaces.go @@ -3,7 +3,7 @@ package runtime import ( "context" - "github.com/openai/openai-go/v3" + "maunium.net/go/mautrix/bridgev2" ) // SettingSource indicates where a setting value came from. @@ -29,9 +29,8 @@ type ToolDefinition struct { // ToolScope carries integration context without coupling to connector internals. type ToolScope struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta } // ToolCall is a concrete tool execution request. @@ -42,13 +41,6 @@ type ToolCall struct { Scope ToolScope } -// PromptScope carries prompt-building context without coupling to connector internals. -type PromptScope struct { - Client any - Portal any - Meta any -} - // ToolIntegration is the pluggable surface for tool definitions/availability/execution. type ToolIntegration interface { Name() string @@ -57,19 +49,24 @@ type ToolIntegration interface { ToolAvailability(ctx context.Context, scope ToolScope, toolName string) (known bool, available bool, source SettingSource, reason string) } -// PromptIntegration is the pluggable surface for prompt/system message augmentation. -type PromptIntegration interface { - Name() string - AdditionalSystemMessages(ctx context.Context, scope PromptScope) []openai.ChatCompletionMessageParamUnion - AugmentPrompt(ctx context.Context, scope PromptScope, prompt []openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion -} - // ToolApprovalIntegration is an optional seam for tool approval policy overrides. type ToolApprovalIntegration interface { Name() string ToolApprovalRequirement(toolName string, args map[string]any) (handled bool, required bool, action string) } +// PromptScope carries typed prompt-building context. +type PromptScope struct { + Portal *bridgev2.Portal + Meta Meta +} + +// PromptContextIntegration contributes additional prompt context text. +type PromptContextIntegration interface { + Name() string + PromptContextText(ctx context.Context, scope PromptScope) string +} + // LifecycleIntegration is an optional capability for integrations that need runtime start/stop hooks. type LifecycleIntegration interface { Start(ctx context.Context) error diff --git a/pkg/integrations/runtime/module_hooks.go b/pkg/integrations/runtime/module_hooks.go index c4d1df31..175b7a3d 100644 --- a/pkg/integrations/runtime/module_hooks.go +++ b/pkg/integrations/runtime/module_hooks.go @@ -5,6 +5,9 @@ import ( "time" "github.com/openai/openai-go/v3" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/bridgev2" ) // ModuleHooks is the base contract every integration module implements. @@ -28,10 +31,8 @@ type CommandDefinition struct { // CommandScope carries command execution context without importing connector internals. type CommandScope struct { - Client any - Portal any - Meta any - Event any + Portal *bridgev2.Portal + Meta Meta } // CommandCall is a concrete command execution request. @@ -63,9 +64,8 @@ const ( // SessionMutationEvent is emitted when chat/session data changes. type SessionMutationEvent struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta SessionKey string Force bool Kind SessionMutationKind @@ -73,9 +73,8 @@ type SessionMutationEvent struct { // FileChangedEvent is emitted when a file write/edit/apply_patch updates workspace data. type FileChangedEvent struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta Path string } @@ -99,9 +98,8 @@ const ( // CompactionLifecycleEvent provides compaction lifecycle details to integrations. type CompactionLifecycleEvent struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta Phase CompactionLifecyclePhase Attempt int ContextWindowTokens int @@ -125,9 +123,8 @@ type CompactionLifecycleIntegration interface { // ContextOverflowCall contains context-overflow retry state. type ContextOverflowCall struct { - Client any - Portal any - Meta any + Portal *bridgev2.Portal + Meta Meta Prompt []openai.ChatCompletionMessageParamUnion RequestedTokens int ModelMaxTokens int @@ -136,8 +133,6 @@ type ContextOverflowCall struct { // LoginScope carries per-login cleanup scope. type LoginScope struct { - Client any - Login any BridgeID string LoginID string } @@ -153,74 +148,40 @@ type LoginPurgeIntegration interface { // nested capability objects or type-asserting optional host adapters. type Host interface { Logger() Logger - RawLogger() any + RawLogger() zerolog.Logger Now() time.Time - ResolvePortalByRoomID(ctx context.Context, roomID string) any - ResolveDefaultPortal(ctx context.Context) any - ResolveLastActivePortal(ctx context.Context, agentID string) any - DispatchInternalMessage(ctx context.Context, portal any, meta any, message string, source string) error - SendAssistantMessage(ctx context.Context, portal any, body string) error - RequestNow(ctx context.Context, reason string) - ToolDefinitionByName(name string) (ToolDefinition, bool) - ExecuteBuiltinTool(ctx context.Context, scope ToolScope, name string, rawArgsJSON string) (string, error) ResolveWorkspaceDir() string - BridgeDB() any + BridgeDB() *dbutil.Database BridgeID() string LoginID() string ModuleEnabled(name string) bool ModuleConfig(name string) map[string]any AgentModuleConfig(agentID string, module string) map[string]any - GetOrCreatePortal(ctx context.Context, portalID string, receiver string, displayName string, setupMeta func(meta any)) (portal any, roomID string, err error) - SavePortal(ctx context.Context, portal any, reason string) error - PortalRoomID(portal any) string - PortalKeyString(portal any) string - - GetModuleMeta(meta any, key string) any - SetModuleMeta(meta any, key string, value any) - AgentIDFromMeta(meta any) string - CompactionCount(meta any) int - IsGroupChat(ctx context.Context, portal any) bool - IsInternalRoom(meta any) bool - PortalMeta(portal any) any - CloneMeta(portal any) any - SetMetaField(meta any, key string, value any) - - RecentMessages(ctx context.Context, portal any, count int) []MessageSummary - LastAssistantMessage(ctx context.Context, portal any) (id string, timestamp int64) - WaitForAssistantMessage(ctx context.Context, portal any, afterID string, afterTS int64) (*AssistantMessageInfo, bool) - - RunHeartbeatOnce(ctx context.Context, reason string) (status string, reasonMsg string) - ResolveHeartbeatSessionPortal(agentID string) (portal any, sessionKey string, err error) - ResolveHeartbeatSessionKey(agentID string) string - HeartbeatAckMaxChars(agentID string) int - EnqueueSystemEvent(sessionKey string, text string, agentID string) - PersistSystemEvents() - ResolveLastTarget(agentID string) (channel string, target string, ok bool) + SavePortal(ctx context.Context, portal *bridgev2.Portal, reason string) error + PortalRoomID(portal *bridgev2.Portal) string + PortalKeyString(portal *bridgev2.Portal) string + + IsGroupChat(ctx context.Context, portal *bridgev2.Portal) bool + + RecentMessages(ctx context.Context, portal *bridgev2.Portal, count int) []MessageSummary ResolveAgentID(raw string, fallbackDefault string) string - NormalizeAgentID(raw string) string - AgentExists(normalizedID string) bool DefaultAgentID() string - AgentTimeoutSeconds() int UserTimezone() (tz string, loc *time.Location) - NormalizeThinkingLevel(raw string) (string, bool) - - EffectiveModel(meta any) string - ContextWindow(meta any) int - MergeDisconnectContext(ctx context.Context) (context.Context, context.CancelFunc) - BackgroundContext(ctx context.Context) context.Context + EffectiveModel(meta Meta) string + ContextWindow(meta Meta) int - NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams any) (*CompletionResult, error) + NewCompletion(ctx context.Context, model string, messages []openai.ChatCompletionMessageParamUnion, toolParams []openai.ChatCompletionToolUnionParam) (*CompletionResult, error) - IsToolEnabled(meta any, toolName string) bool + IsToolEnabled(meta Meta, toolName string) bool AllToolDefinitions() []ToolDefinition - ExecuteToolInContext(ctx context.Context, portal any, meta any, name string, argsJSON string) (string, error) - ToolsToOpenAIParams(tools []ToolDefinition) any + ExecuteToolInContext(ctx context.Context, portal *bridgev2.Portal, meta Meta, name string, argsJSON string) (string, error) + ToolsToOpenAIParams(tools []ToolDefinition) []openai.ChatCompletionToolUnionParam ReadTextFile(ctx context.Context, agentID string, path string) (content string, filePath string, found bool, err error) - WriteTextFile(ctx context.Context, portal any, meta any, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) + WriteTextFile(ctx context.Context, portal *bridgev2.Portal, meta Meta, agentID string, mode string, path string, content string, maxBytes int) (finalPath string, err error) SmartTruncatePrompt(prompt []openai.ChatCompletionMessageParamUnion, ratio float64) []openai.ChatCompletionMessageParamUnion EstimateTokens(prompt []openai.ChatCompletionMessageParamUnion, model string) int @@ -230,7 +191,7 @@ type Host interface { IsLoggedIn() bool SessionPortals(ctx context.Context, loginID string, agentID string) ([]SessionPortalInfo, error) - LoginDB() any + LoginDB() *dbutil.Database } // Logger is a minimal structured logger abstraction. diff --git a/sdk/client.go b/sdk/client.go index 14f9da5d..67d4fa5b 100644 --- a/sdk/client.go +++ b/sdk/client.go @@ -17,18 +17,18 @@ import ( // Compile-time interface checks. var ( - _ bridgev2.NetworkAPI = (*sdkClient)(nil) - _ bridgev2.EditHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.ReactionHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.RedactionHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.TypingHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.RoomNameHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.RoomTopicHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.BackfillingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.DeleteChatHandlingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.IdentifierResolvingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.ContactListingNetworkAPI = (*sdkClient)(nil) - _ bridgev2.UserSearchingNetworkAPI = (*sdkClient)(nil) + _ bridgev2.NetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.EditHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.ReactionHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.RedactionHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.TypingHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.RoomNameHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.RoomTopicHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.BackfillingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.DeleteChatHandlingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.IdentifierResolvingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.ContactListingNetworkAPI = (*sdkClient[any, any])(nil) + _ bridgev2.UserSearchingNetworkAPI = (*sdkClient[any, any])(nil) ) // pendingSDKApprovalData holds SDK-specific metadata for a pending tool approval. @@ -39,19 +39,19 @@ type pendingSDKApprovalData struct { ToolName string } -type sdkClient struct { +type sdkClient[SessionT SessionValue, ConfigDataT ConfigValue] struct { agentremote.ClientBase - cfg *Config + cfg *Config[SessionT, ConfigDataT] userLogin *bridgev2.UserLogin approvalFlow *agentremote.ApprovalFlow[*pendingSDKApprovalData] turnManager *TurnManager conversationState *conversationStateStore sessionMu sync.RWMutex - session any + session SessionT } -func newSDKClient(login *bridgev2.UserLogin, cfg *Config) *sdkClient { +func newSDKClient[SessionT SessionValue, ConfigDataT ConfigValue](login *bridgev2.UserLogin, cfg *Config[SessionT, ConfigDataT]) *sdkClient[SessionT, ConfigDataT] { identity := resolveProviderIdentity(cfg) senderForPortal := func(*bridgev2.Portal) bridgev2.EventSender { if cfg != nil && cfg.Agent != nil { @@ -59,7 +59,7 @@ func newSDKClient(login *bridgev2.UserLogin, cfg *Config) *sdkClient { } return bridgev2.EventSender{} } - c := &sdkClient{ + c := &sdkClient[SessionT, ConfigDataT]{ cfg: cfg, userLogin: login, conversationState: newConversationStateStore(), @@ -86,44 +86,82 @@ func newSDKClient(login *bridgev2.UserLogin, cfg *Config) *sdkClient { return c } -func (c *sdkClient) GetApprovalHandler() agentremote.ApprovalReactionHandler { +func (c *sdkClient[SessionT, ConfigDataT]) GetApprovalHandler() agentremote.ApprovalReactionHandler { return c.approvalFlow } -func (c *sdkClient) config() *Config { return c.cfg } +func (c *sdkClient[SessionT, ConfigDataT]) agent() *Agent { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.Agent +} -func (c *sdkClient) sessionValue() any { return c.getSession() } +func (c *sdkClient[SessionT, ConfigDataT]) agentCatalog() AgentCatalog { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.AgentCatalog +} -func (c *sdkClient) conversationStore() *conversationStateStore { return c.conversationState } +func (c *sdkClient[SessionT, ConfigDataT]) roomFeatures(conv *Conversation) *RoomFeatures { + if c == nil || c.cfg == nil { + return nil + } + if c.cfg.GetCapabilities != nil { + if rf := c.cfg.GetCapabilities(c.getSession(), conv); rf != nil { + return rf + } + } + return c.cfg.RoomFeatures +} + +func (c *sdkClient[SessionT, ConfigDataT]) commands() []Command { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.Commands +} + +func (c *sdkClient[SessionT, ConfigDataT]) turnConfig() *TurnConfig { + if c == nil || c.cfg == nil { + return nil + } + return c.cfg.TurnManagement +} + +func (c *sdkClient[SessionT, ConfigDataT]) conversationStore() *conversationStateStore { + return c.conversationState +} -func (c *sdkClient) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { +func (c *sdkClient[SessionT, ConfigDataT]) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { return c.approvalFlow } -func (c *sdkClient) providerIdentity() ProviderIdentity { +func (c *sdkClient[SessionT, ConfigDataT]) providerIdentity() ProviderIdentity { return resolveProviderIdentity(c.cfg) } -func (c *sdkClient) getSession() any { +func (c *sdkClient[SessionT, ConfigDataT]) getSession() SessionT { c.sessionMu.RLock() defer c.sessionMu.RUnlock() return c.session } -func (c *sdkClient) setSession(s any) { +func (c *sdkClient[SessionT, ConfigDataT]) setSession(s SessionT) { c.sessionMu.Lock() c.session = s c.sessionMu.Unlock() } // Connect implements bridgev2.NetworkAPI. -func (c *sdkClient) Connect(ctx context.Context) { - if c.config().OnConnect != nil { +func (c *sdkClient[SessionT, ConfigDataT]) Connect(ctx context.Context) { + if c.cfg != nil && c.cfg.OnConnect != nil { info := &LoginInfo{ Login: c.userLogin, UserID: string(c.userLogin.UserMXID), } - session, err := c.config().OnConnect(ctx, info) + session, err := c.cfg.OnConnect(ctx, info) if err != nil { c.userLogin.BridgeState.Send(status.BridgeState{ StateEvent: status.StateUnknownError, @@ -137,55 +175,56 @@ func (c *sdkClient) Connect(ctx context.Context) { c.userLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) } -func (c *sdkClient) Disconnect() { +func (c *sdkClient[SessionT, ConfigDataT]) Disconnect() { c.SetLoggedIn(false) if c.approvalFlow != nil { c.approvalFlow.Close() } c.CloseAllSessions() - if c.config().OnDisconnect != nil { - c.config().OnDisconnect(c.getSession()) + if c.cfg != nil && c.cfg.OnDisconnect != nil { + c.cfg.OnDisconnect(c.getSession()) } - c.setSession(nil) + var zero SessionT + c.setSession(zero) } -func (c *sdkClient) LogoutRemote(ctx context.Context) { +func (c *sdkClient[SessionT, ConfigDataT]) LogoutRemote(ctx context.Context) { c.Disconnect() } -func (c *sdkClient) IsThisUser(_ context.Context, userID networkid.UserID) bool { - if c.config().IsThisUser != nil { - return c.config().IsThisUser(string(userID)) +func (c *sdkClient[SessionT, ConfigDataT]) IsThisUser(_ context.Context, userID networkid.UserID) bool { + if c.cfg != nil && c.cfg.IsThisUser != nil { + return c.cfg.IsThisUser(string(userID)) } return false } -func (c *sdkClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { - if c.config().GetChatInfo != nil { - return c.config().GetChatInfo(c.conv(ctx, portal)) +func (c *sdkClient[SessionT, ConfigDataT]) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + if c.cfg != nil && c.cfg.GetChatInfo != nil { + return c.cfg.GetChatInfo(c.conv(ctx, portal)) } return nil, nil } -func (c *sdkClient) GetUserInfo(_ context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { - if c.config().GetUserInfo != nil { - return c.config().GetUserInfo(ghost) +func (c *sdkClient[SessionT, ConfigDataT]) GetUserInfo(_ context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { + if c.cfg != nil && c.cfg.GetUserInfo != nil { + return c.cfg.GetUserInfo(ghost) } return nil, nil } -func (c *sdkClient) GetCapabilities(_ context.Context, portal *bridgev2.Portal) *event.RoomFeatures { +func (c *sdkClient[SessionT, ConfigDataT]) GetCapabilities(_ context.Context, portal *bridgev2.Portal) *event.RoomFeatures { conv := c.conv(context.Background(), portal) return convertRoomFeatures(conv.currentRoomFeatures(context.Background())) } -func (c *sdkClient) conv(ctx context.Context, portal *bridgev2.Portal) *Conversation { +func (c *sdkClient[SessionT, ConfigDataT]) conv(ctx context.Context, portal *bridgev2.Portal) *Conversation { return newConversation(ctx, portal, c.userLogin, bridgev2.EventSender{}, c) } // HandleMatrixMessage dispatches incoming messages to the OnMessage callback. -func (c *sdkClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { - if c.config().OnMessage == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { + if c.cfg == nil || c.cfg.OnMessage == nil { return nil, nil } runCtx := c.BackgroundContext(ctx) @@ -203,7 +242,7 @@ func (c *sdkClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri roomID = c.turnManager.ResolveKey(roomID) } run := func(turnCtx context.Context) error { - return c.config().OnMessage(session, conv, sdkMsg, turn) + return c.cfg.OnMessage(session, conv, sdkMsg, turn) } go func() { var err error @@ -272,8 +311,8 @@ func convertMatrixMessage(msg *bridgev2.MatrixMessage) *Message { } // HandleMatrixEdit implements bridgev2.EditHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixEdit) error { - if c.config().OnEdit == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixEdit) error { + if c.cfg == nil || c.cfg.OnEdit == nil { return nil } me := &MessageEdit{ @@ -284,81 +323,81 @@ func (c *sdkClient) HandleMatrixEdit(ctx context.Context, edit *bridgev2.MatrixE me.NewText = edit.Content.Body me.NewHTML = edit.Content.FormattedBody } - return c.config().OnEdit(c.getSession(), c.conv(ctx, edit.Portal), me) + return c.cfg.OnEdit(c.getSession(), c.conv(ctx, edit.Portal), me) } // HandleMatrixMessageRemove implements bridgev2.RedactionHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { - if c.config().OnDelete == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { + if c.cfg == nil || c.cfg.OnDelete == nil { return nil } var msgID string if msg.TargetMessage != nil { msgID = string(msg.TargetMessage.ID) } - return c.config().OnDelete(c.getSession(), c.conv(ctx, msg.Portal), msgID) + return c.cfg.OnDelete(c.getSession(), c.conv(ctx, msg.Portal), msgID) } // HandleMatrixTyping implements bridgev2.TypingHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixTyping(ctx context.Context, msg *bridgev2.MatrixTyping) error { - if c.config().OnTyping != nil { - c.config().OnTyping(c.getSession(), c.conv(ctx, msg.Portal), msg.IsTyping) +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixTyping(ctx context.Context, msg *bridgev2.MatrixTyping) error { + if c.cfg != nil && c.cfg.OnTyping != nil { + c.cfg.OnTyping(c.getSession(), c.conv(ctx, msg.Portal), msg.IsTyping) } return nil } // HandleMatrixRoomName implements bridgev2.RoomNameHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixRoomName(ctx context.Context, msg *bridgev2.MatrixRoomName) (bool, error) { - if c.config().OnRoomName != nil { - return c.config().OnRoomName(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Name) +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixRoomName(ctx context.Context, msg *bridgev2.MatrixRoomName) (bool, error) { + if c.cfg != nil && c.cfg.OnRoomName != nil { + return c.cfg.OnRoomName(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Name) } return false, nil } // HandleMatrixRoomTopic implements bridgev2.RoomTopicHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixRoomTopic(ctx context.Context, msg *bridgev2.MatrixRoomTopic) (bool, error) { - if c.config().OnRoomTopic != nil { - return c.config().OnRoomTopic(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Topic) +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixRoomTopic(ctx context.Context, msg *bridgev2.MatrixRoomTopic) (bool, error) { + if c.cfg != nil && c.cfg.OnRoomTopic != nil { + return c.cfg.OnRoomTopic(c.getSession(), c.conv(ctx, msg.Portal), msg.Content.Topic) } return false, nil } // FetchMessages implements bridgev2.BackfillingNetworkAPI. -func (c *sdkClient) FetchMessages(ctx context.Context, params bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) { - if c.config().FetchMessages == nil { +func (c *sdkClient[SessionT, ConfigDataT]) FetchMessages(ctx context.Context, params bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) { + if c.cfg == nil || c.cfg.FetchMessages == nil { return nil, nil } - return c.config().FetchMessages(ctx, params) + return c.cfg.FetchMessages(ctx, params) } // HandleMatrixDeleteChat implements bridgev2.DeleteChatHandlingNetworkAPI. -func (c *sdkClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { - if c.config().DeleteChat == nil { +func (c *sdkClient[SessionT, ConfigDataT]) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { + if c.cfg == nil || c.cfg.DeleteChat == nil { return nil } - return c.config().DeleteChat(c.conv(ctx, msg.Portal)) + return c.cfg.DeleteChat(c.conv(ctx, msg.Portal)) } // ResolveIdentifier implements bridgev2.IdentifierResolvingNetworkAPI. -func (c *sdkClient) ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { - if c.config().ResolveIdentifier == nil { +func (c *sdkClient[SessionT, ConfigDataT]) ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + if c.cfg == nil || c.cfg.ResolveIdentifier == nil { return nil, nil } - return c.config().ResolveIdentifier(ctx, c.getSession(), identifier, createChat) + return c.cfg.ResolveIdentifier(ctx, c.getSession(), identifier, createChat) } // GetContactList implements bridgev2.ContactListingNetworkAPI. -func (c *sdkClient) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { - if c.config().GetContactList == nil { +func (c *sdkClient[SessionT, ConfigDataT]) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { + if c.cfg == nil || c.cfg.GetContactList == nil { return nil, nil } - return c.config().GetContactList(ctx, c.getSession()) + return c.cfg.GetContactList(ctx, c.getSession()) } // SearchUsers implements bridgev2.UserSearchingNetworkAPI. -func (c *sdkClient) SearchUsers(ctx context.Context, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { - if c.config().SearchUsers == nil { +func (c *sdkClient[SessionT, ConfigDataT]) SearchUsers(ctx context.Context, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { + if c.cfg == nil || c.cfg.SearchUsers == nil { return nil, nil } - return c.config().SearchUsers(ctx, c.getSession(), query) + return c.cfg.SearchUsers(ctx, c.getSession(), query) } diff --git a/sdk/client_resolution_test.go b/sdk/client_resolution_test.go index cca50dab..ba9d44e6 100644 --- a/sdk/client_resolution_test.go +++ b/sdk/client_resolution_test.go @@ -14,8 +14,8 @@ func TestSDKClientResolveIdentifierPreservesFullResponse(t *testing.T) { chat := &bridgev2.CreateChatResponse{ PortalKey: networkid.PortalKey{ID: "portal-1", Receiver: "login-1"}, } - cfg := &Config{ - ResolveIdentifier: func(_ context.Context, _ any, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + cfg := &Config[*bridgev2.UserLogin, *struct{}]{ + ResolveIdentifier: func(_ context.Context, _ *bridgev2.UserLogin, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { if id != "agent:test" { t.Fatalf("unexpected identifier %q", id) } @@ -50,11 +50,11 @@ func TestSDKClientResolveIdentifierPreservesFullResponse(t *testing.T) { func TestSDKClientContactListingAndSearch(t *testing.T) { contact := &bridgev2.ResolveIdentifierResponse{UserID: "agent-user"} - cfg := &Config{ - GetContactList: func(_ context.Context, _ any) ([]*bridgev2.ResolveIdentifierResponse, error) { + cfg := &Config[*bridgev2.UserLogin, *struct{}]{ + GetContactList: func(_ context.Context, _ *bridgev2.UserLogin) ([]*bridgev2.ResolveIdentifierResponse, error) { return []*bridgev2.ResolveIdentifierResponse{contact}, nil }, - SearchUsers: func(_ context.Context, _ any, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { + SearchUsers: func(_ context.Context, _ *bridgev2.UserLogin, query string) ([]*bridgev2.ResolveIdentifierResponse, error) { if query != "agent" { t.Fatalf("unexpected query %q", query) } diff --git a/sdk/commands.go b/sdk/commands.go index 9e021239..aadb7b74 100644 --- a/sdk/commands.go +++ b/sdk/commands.go @@ -15,7 +15,7 @@ import ( var sdkHelpSection = commands.HelpSection{Name: "SDK", Order: 50} // registerCommands registers Config.Commands with the bridgev2 command processor. -func registerCommands(br *bridgev2.Bridge, cfg *Config) { +func registerCommands[SessionT SessionValue, ConfigDataT ConfigValue](br *bridgev2.Bridge, cfg *Config[SessionT, ConfigDataT]) { if len(cfg.Commands) == 0 || br == nil { return } diff --git a/sdk/connector.go b/sdk/connector.go index c6061593..65fbf499 100644 --- a/sdk/connector.go +++ b/sdk/connector.go @@ -15,8 +15,7 @@ import ( ) // NewConnectorBase builds an SDK-backed connector base that can be embedded by custom bridges. -func NewConnectorBase(cfg *Config) *agentremote.ConnectorBase { - var br *bridgev2.Bridge +func NewConnectorBase[SessionT SessionValue, ConfigDataT ConfigValue](cfg *Config[SessionT, ConfigDataT]) *agentremote.ConnectorBase { mu, clientsRef := cfg.ClientCacheMu, cfg.ClientCache if mu == nil { mu = &sync.Mutex{} @@ -45,7 +44,7 @@ func NewConnectorBase(cfg *Config) *agentremote.ConnectorBase { cfg.UpdateClient(client, login) return } - if typed, ok := client.(*sdkClient); ok { + if typed, ok := client.(*sdkClient[SessionT, ConfigDataT]); ok { typed.SetUserLogin(login) } }, @@ -66,23 +65,22 @@ func NewConnectorBase(cfg *Config) *agentremote.ConnectorBase { return agentremote.NewConnector(agentremote.ConnectorSpec{ ProtocolID: protocolID, Init: func(bridge *bridgev2.Bridge) { - br = bridge agentremote.EnsureClientMap(mu, clientsRef) if cfg.InitConnector != nil { cfg.InitConnector(bridge) } }, - Start: func(ctx context.Context) error { - registerCommands(br, cfg) + Start: func(ctx context.Context, bridge *bridgev2.Bridge) error { + registerCommands(bridge, cfg) if cfg.StartConnector != nil { - return cfg.StartConnector(ctx, br) + return cfg.StartConnector(ctx, bridge) } return nil }, - Stop: func(ctx context.Context) { + Stop: func(ctx context.Context, bridge *bridgev2.Bridge) { agentremote.StopClients(mu, clientsRef) if cfg.StopConnector != nil { - cfg.StopConnector(ctx, br) + cfg.StopConnector(ctx, bridge) } }, Name: func() bridgev2.BridgeName { diff --git a/sdk/connector_helpers.go b/sdk/connector_helpers.go index 139f5221..4fd0c7e8 100644 --- a/sdk/connector_helpers.go +++ b/sdk/connector_helpers.go @@ -16,13 +16,18 @@ import ( ) // BuildStandardMetaTypes returns the common bridge metadata registrations. -func BuildStandardMetaTypes( - newPortal func() any, - newMessage func() any, - newLogin func() any, - newGhost func() any, +func BuildStandardMetaTypes[PortalT, MessageT, LoginT, GhostT any]( + newPortal func() PortalT, + newMessage func() MessageT, + newLogin func() LoginT, + newGhost func() GhostT, ) database.MetaTypes { - return agentremote.BuildMetaTypes(newPortal, newMessage, newLogin, newGhost) + return agentremote.BuildMetaTypes( + func() any { return newPortal() }, + func() any { return newMessage() }, + func() any { return newLogin() }, + func() any { return newGhost() }, + ) } // ApplyDefaultCommandPrefix sets the command prefix when it is empty. @@ -32,6 +37,16 @@ func ApplyDefaultCommandPrefix(prefix *string, value string) { } } +// ResolveCommandPrefix returns the configured prefix when present, otherwise the +// bridge's declared default prefix without mutating configuration state. +func ResolveCommandPrefix(prefix string, fallback string) string { + trimmed := strings.TrimSpace(prefix) + if trimmed != "" { + return trimmed + } + return fallback +} + // ApplyBoolDefault initializes a nil bool pointer to the provided value. func ApplyBoolDefault(target **bool, value bool) { if target == nil || *target != nil { @@ -79,7 +94,7 @@ func TypedClientUpdater[T interface { } } -type StandardConnectorConfigParams struct { +type StandardConnectorConfigParams[SessionT SessionValue, ConfigDataT ConfigValue, PortalT, MessageT, LoginT, GhostT any] struct { Name string Description string ProtocolID string @@ -87,7 +102,7 @@ type StandardConnectorConfigParams struct { ClientCacheMu *sync.Mutex ClientCache *map[networkid.UserLoginID]bridgev2.NetworkAPI AgentCatalog AgentCatalog - GetCapabilities func(session any, conv *Conversation) *RoomFeatures + GetCapabilities func(session SessionT, conv *Conversation) *RoomFeatures InitConnector func(br *bridgev2.Bridge) StartConnector func(ctx context.Context, br *bridgev2.Bridge) error StopConnector func(ctx context.Context, br *bridgev2.Bridge) @@ -99,12 +114,12 @@ type StandardConnectorConfigParams struct { DefaultPort uint16 DefaultCommandPrefix func() string ExampleConfig string - ConfigData any + ConfigData ConfigDataT ConfigUpgrader configupgrade.Upgrader - NewPortal func() any - NewMessage func() any - NewLogin func() any - NewGhost func() any + NewPortal func() PortalT + NewMessage func() MessageT + NewLogin func() LoginT + NewGhost func() GhostT NetworkCapabilities func() *bridgev2.NetworkGeneralCapabilities FillBridgeInfo func(portal *bridgev2.Portal, content *event.BridgeEventContent) AcceptLogin func(login *bridgev2.UserLogin) (bool, string) @@ -120,8 +135,8 @@ type StandardConnectorConfigParams struct { // NewStandardConnectorConfig builds the common bridgesdk.Config skeleton used by // the dedicated bridge connectors. -func NewStandardConnectorConfig(p StandardConnectorConfigParams) *Config { - return &Config{ +func NewStandardConnectorConfig[SessionT SessionValue, ConfigDataT ConfigValue, PortalT, MessageT, LoginT, GhostT any](p StandardConnectorConfigParams[SessionT, ConfigDataT, PortalT, MessageT, LoginT, GhostT]) *Config[SessionT, ConfigDataT] { + return &Config[SessionT, ConfigDataT]{ Name: p.Name, Description: p.Description, ProtocolID: p.ProtocolID, diff --git a/sdk/connector_hooks_test.go b/sdk/connector_hooks_test.go index 4fc34874..2207f7b0 100644 --- a/sdk/connector_hooks_test.go +++ b/sdk/connector_hooks_test.go @@ -57,8 +57,9 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { createCalled := 0 updateCalled := 0 afterLoadCalled := 0 + wantBridge := &bridgev2.Bridge{} - cfg := &Config{ + cfg := &Config[*struct{}, *struct{}]{ Name: "hooked", ClientCacheMu: &mu, ClientCache: &clients, @@ -68,12 +69,25 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { } return true, "" }, - InitConnector: func(*bridgev2.Bridge) { initCalled++ }, - StartConnector: func(context.Context, *bridgev2.Bridge) error { + InitConnector: func(got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected init bridge %p, got %p", wantBridge, got) + } + initCalled++ + }, + StartConnector: func(_ context.Context, got *bridgev2.Bridge) error { + if got != wantBridge { + t.Fatalf("expected start bridge %p, got %p", wantBridge, got) + } startCalled++ return nil }, - StopConnector: func(context.Context, *bridgev2.Bridge) { stopCalled++ }, + StopConnector: func(_ context.Context, got *bridgev2.Bridge) { + if got != wantBridge { + t.Fatalf("expected stop bridge %p, got %p", wantBridge, got) + } + stopCalled++ + }, MakeBrokenLogin: func(login *bridgev2.UserLogin, reason string) *agentremote.BrokenLoginClient { return agentremote.NewBrokenLoginClient(login, "custom:"+reason) }, @@ -89,7 +103,7 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { } conn := NewConnectorBase(cfg) - conn.Init(nil) + conn.Init(wantBridge) if err := conn.Start(context.Background()); err != nil { t.Fatalf("start returned error: %v", err) } @@ -131,7 +145,7 @@ func TestNewConnectorBaseUsesHooksAndCustomClients(t *testing.T) { func TestNewConnectorBaseUsesCustomLoadLoginAndLoginFlows(t *testing.T) { loadCalled := 0 - cfg := &Config{ + cfg := &Config[*struct{}, *struct{}]{ Name: "custom-load", LoadLogin: func(_ context.Context, login *bridgev2.UserLogin) error { loadCalled++ @@ -174,7 +188,7 @@ func TestNewConnectorBaseUsesCustomLoadLoginAndLoginFlows(t *testing.T) { } func TestApprovalControllerUsesCustomHandler(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{}, nil) + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), &Agent{ID: "agent"}, nil) called := false @@ -202,4 +216,13 @@ func TestApprovalControllerUsesCustomHandler(t *testing.T) { } } +func TestResolveCommandPrefixTrimsConfiguredValue(t *testing.T) { + if got := ResolveCommandPrefix(" /ai ", "!fallback"); got != "/ai" { + t.Fatalf("expected trimmed configured prefix, got %q", got) + } + if got := ResolveCommandPrefix(" ", "!fallback"); got != "!fallback" { + t.Fatalf("expected fallback prefix, got %q", got) + } +} + var _ bridgev2.NetworkAPI = (*testSDKClient)(nil) diff --git a/sdk/conversation.go b/sdk/conversation.go index ba8ec271..b031c650 100644 --- a/sdk/conversation.go +++ b/sdk/conversation.go @@ -60,13 +60,6 @@ func (c *Conversation) getIntent(ctx context.Context) (bridgev2.MatrixAPI, error return intent, nil } -func (c *Conversation) configOrNil() *Config { - if c.runtime == nil { - return nil - } - return c.runtime.config() -} - func (c *Conversation) stateStore() *conversationStateStore { if c == nil || c.runtime == nil { return nil @@ -97,15 +90,14 @@ func (c *Conversation) resolveDefaultAgent(ctx context.Context) (*Agent, error) return agent, nil } } - cfg := c.configOrNil() - if cfg == nil { + if c.runtime == nil { return nil, nil } - if cfg.Agent != nil { - return cfg.Agent, nil + if agent := c.runtime.agent(); agent != nil { + return agent, nil } - if cfg.AgentCatalog != nil { - return cfg.AgentCatalog.DefaultAgent(ctx, c.login) + if catalog := c.runtime.agentCatalog(); catalog != nil { + return catalog.DefaultAgent(ctx, c.login) } return nil, nil } @@ -114,15 +106,14 @@ func (c *Conversation) resolveAgentByIdentifier(ctx context.Context, identifier if c == nil || strings.TrimSpace(identifier) == "" { return nil, nil } - cfg := c.configOrNil() - if cfg == nil { + if c.runtime == nil { return nil, nil } - if cfg.Agent != nil && cfg.Agent.ID == identifier { - return cfg.Agent, nil + if agent := c.runtime.agent(); agent != nil && agent.ID == identifier { + return agent, nil } - if cfg.AgentCatalog != nil { - return cfg.AgentCatalog.ResolveAgent(ctx, c.login, identifier) + if catalog := c.runtime.agentCatalog(); catalog != nil { + return catalog.ResolveAgent(ctx, c.login, identifier) } return nil, nil } @@ -131,9 +122,8 @@ func (c *Conversation) currentRoomFeatures(ctx context.Context) *RoomFeatures { if c == nil { return nil } - cfg := c.configOrNil() - if cfg != nil && cfg.GetCapabilities != nil { - if rf := cfg.GetCapabilities(c.runtime.sessionValue(), c); rf != nil { + if c.runtime != nil { + if rf := c.runtime.roomFeatures(c); rf != nil { return rf } } @@ -152,9 +142,6 @@ func (c *Conversation) currentRoomFeatures(ctx context.Context) *RoomFeatures { } } if len(agents) == 0 { - if cfg != nil && cfg.RoomFeatures != nil { - return cfg.RoomFeatures - } return defaultSDKFeatureConfig() } return computeRoomFeaturesForAgents(agents) @@ -253,14 +240,6 @@ func (c *Conversation) StartTurn(ctx context.Context, agent *Agent, source *Sour return newTurn(ctx, c, agent, source) } -// Session returns the session state from the client, if available. -func (c *Conversation) Session() any { - if c.runtime == nil { - return nil - } - return c.runtime.sessionValue() -} - // Context returns the conversation's context. func (c *Conversation) Context() context.Context { return c.ctx diff --git a/sdk/conversation_state_test.go b/sdk/conversation_state_test.go index a5775253..8cdc1830 100644 --- a/sdk/conversation_state_test.go +++ b/sdk/conversation_state_test.go @@ -79,8 +79,10 @@ func TestConversationStateRoundTripCarrierMetadata(t *testing.T) { AgentIDs: []string{"agent-a"}, }, } - if !saveConversationStateToGenericMetadata(&holder, state) { - // Generic metadata intentionally doesn't support the carrier path. + // saveConversationStateToGenericMetadata intentionally returns false here + // because generic metadata doesn't support the carrier path. + if ok := saveConversationStateToGenericMetadata(&holder, state); ok { + t.Fatalf("expected generic metadata save to report unsupported carrier path") } carrier.SetSDKPortalMetadata(&SDKPortalMetadata{Conversation: *state}) loaded, ok := carrier.GetSDKPortalMetadata(), carrier.GetSDKPortalMetadata() != nil diff --git a/sdk/conversation_test.go b/sdk/conversation_test.go index d6800614..1b7697ce 100644 --- a/sdk/conversation_test.go +++ b/sdk/conversation_test.go @@ -25,7 +25,7 @@ func (c testAgentCatalog) ResolveAgent(_ context.Context, _ *bridgev2.UserLogin, return c.byIdentifier[identifier], nil } -func newTestConversation(cfg *Config, state sdkConversationState) *Conversation { +func newTestConversation(cfg *Config[struct{}, *struct{}], state sdkConversationState) *Conversation { return newConversation( context.Background(), &bridgev2.Portal{ @@ -36,12 +36,12 @@ func newTestConversation(cfg *Config, state sdkConversationState) *Conversation }, nil, bridgev2.EventSender{}, - &staticRuntime{cfg: cfg}, + &staticRuntime[struct{}, *struct{}]{cfg: cfg}, ) } func TestConversationCurrentRoomFeaturesUsesConfiguredDefaultAgent(t *testing.T) { - conv := newTestConversation(&Config{ + conv := newTestConversation(&Config[struct{}, *struct{}]{ Agent: &Agent{ ID: "default", Capabilities: AgentCapabilities{ @@ -61,7 +61,7 @@ func TestConversationCurrentRoomFeaturesUsesConfiguredDefaultAgent(t *testing.T) } func TestConversationCurrentRoomFeaturesFallsBackAfterUnresolvedAgents(t *testing.T) { - conv := newTestConversation(&Config{ + conv := newTestConversation(&Config[struct{}, *struct{}]{ Agent: &Agent{ ID: "default", Capabilities: AgentCapabilities{ @@ -83,7 +83,7 @@ func TestConversationCurrentRoomFeaturesFallsBackAfterUnresolvedAgents(t *testin } func TestConversationCurrentRoomFeaturesIgnoresUnresolvedAgentsWhenOneResolves(t *testing.T) { - conv := newTestConversation(&Config{ + conv := newTestConversation(&Config[struct{}, *struct{}]{ AgentCatalog: testAgentCatalog{ byIdentifier: map[string]*Agent{ "found": { diff --git a/sdk/login_handle.go b/sdk/login_handle.go index a91710a8..94feb9b8 100644 --- a/sdk/login_handle.go +++ b/sdk/login_handle.go @@ -70,10 +70,10 @@ func (l *LoginHandle) EnsureConversation(ctx context.Context, spec ConversationS AIRoomKind: conv.aiRoomKind(), ForceCapabilities: true, RefreshExtra: func(ctx context.Context, portal *bridgev2.Portal) { - if l.runtime == nil || l.runtime.config() == nil || len(l.runtime.config().Commands) == 0 { + if l.runtime == nil || len(l.runtime.commands()) == 0 { return } - BroadcastCommandDescriptions(ctx, portal, l.login.Bridge.Bot, l.runtime.config().Commands) + BroadcastCommandDescriptions(ctx, portal, l.login.Bridge.Bot, l.runtime.commands()) }, }) if err != nil { diff --git a/sdk/part_apply_test.go b/sdk/part_apply_test.go index 3ccf9a1c..25c478dd 100644 --- a/sdk/part_apply_test.go +++ b/sdk/part_apply_test.go @@ -8,7 +8,7 @@ import ( ) func newPartApplyTestTurn() *Turn { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{}, nil) + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{}, nil) return conv.StartTurn(context.Background(), &Agent{ID: "agent"}, nil) } diff --git a/sdk/prompt_context.go b/sdk/prompt_context.go deleted file mode 100644 index 109e6f73..00000000 --- a/sdk/prompt_context.go +++ /dev/null @@ -1,557 +0,0 @@ -package sdk - -import ( - "fmt" - "slices" - "strings" - - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" - "github.com/openai/openai-go/v3/responses" -) - -// PromptContext is the canonical provider-facing prompt representation. -type PromptContext struct { - SystemPrompt string - DeveloperPrompt string - Messages []PromptMessage -} - -func UserPromptContext(blocks ...PromptBlock) PromptContext { - return PromptContext{ - Messages: []PromptMessage{{ - Role: PromptRoleUser, - Blocks: slices.Clone(blocks), - }}, - } -} - -func PromptContextHasBlockType(ctx PromptContext, kinds ...PromptBlockType) bool { - if len(kinds) == 0 { - return false - } - allowed := make(map[PromptBlockType]struct{}, len(kinds)) - for _, kind := range kinds { - allowed[kind] = struct{}{} - } - for _, msg := range ctx.Messages { - for _, block := range msg.Blocks { - if _, ok := allowed[block.Type]; ok { - return true - } - } - } - return false -} - -// ChatMessagesToPromptContext converts chat-completions-shaped messages into the canonical prompt model. -func ChatMessagesToPromptContext(messages []openai.ChatCompletionMessageParamUnion) PromptContext { - var ctx PromptContext - AppendChatMessagesToPromptContext(&ctx, messages) - return ctx -} - -func AppendChatMessagesToPromptContext(ctx *PromptContext, messages []openai.ChatCompletionMessageParamUnion) { - if ctx == nil { - return - } - for _, msg := range messages { - appendChatMessageToPromptContext(ctx, msg) - } -} - -func appendChatMessageToPromptContext(ctx *PromptContext, msg openai.ChatCompletionMessageParamUnion) { - if ctx == nil { - return - } - switch { - case msg.OfSystem != nil: - AppendPromptText(&ctx.SystemPrompt, extractChatSystemText(msg.OfSystem.Content)) - case msg.OfDeveloper != nil: - AppendPromptText(&ctx.DeveloperPrompt, extractChatDeveloperText(msg.OfDeveloper.Content)) - case msg.OfUser != nil: - ctx.Messages = append(ctx.Messages, promptMessageFromChatUser(msg.OfUser)) - case msg.OfAssistant != nil: - ctx.Messages = append(ctx.Messages, promptMessageFromChatAssistant(msg.OfAssistant)) - case msg.OfTool != nil: - ctx.Messages = append(ctx.Messages, promptMessageFromChatTool(msg.OfTool)) - } -} - -func AppendPromptText(dst *string, text string) { - text = strings.TrimSpace(text) - if text == "" { - return - } - if *dst == "" { - *dst = text - return - } - *dst = strings.TrimSpace(*dst + "\n\n" + text) -} - -func promptMessageFromChatUser(msg *openai.ChatCompletionUserMessageParam) PromptMessage { - pm := PromptMessage{Role: PromptRoleUser} - if msg == nil { - return pm - } - if msg.Content.OfString.Value != "" { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: msg.Content.OfString.Value, - }) - } - for _, part := range msg.Content.OfArrayOfContentParts { - pm.Blocks = append(pm.Blocks, promptBlockFromChatUserPart(part)...) - } - return pm -} - -func promptBlockFromChatUserPart(part openai.ChatCompletionContentPartUnionParam) []PromptBlock { - switch { - case part.OfText != nil: - return []PromptBlock{{Type: PromptBlockText, Text: part.OfText.Text}} - case part.OfImageURL != nil: - return []PromptBlock{{ - Type: PromptBlockImage, - ImageURL: part.OfImageURL.ImageURL.URL, - MimeType: inferPromptMimeTypeFromDataURL(part.OfImageURL.ImageURL.URL), - }} - case part.OfFile != nil: - return []PromptBlock{{ - Type: PromptBlockFile, - FileB64: part.OfFile.File.FileData.Value, - Filename: part.OfFile.File.Filename.Value, - }} - case part.OfInputAudio != nil: - return []PromptBlock{{ - Type: PromptBlockAudio, - AudioB64: part.OfInputAudio.InputAudio.Data, - AudioFormat: part.OfInputAudio.InputAudio.Format, - }} - default: - return nil - } -} - -func promptMessageFromChatAssistant(msg *openai.ChatCompletionAssistantMessageParam) PromptMessage { - pm := PromptMessage{Role: PromptRoleAssistant} - if msg == nil { - return pm - } - if msg.Content.OfString.Value != "" { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: msg.Content.OfString.Value, - }) - } - for _, part := range msg.Content.OfArrayOfContentParts { - if part.OfText == nil { - continue - } - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: part.OfText.Text, - }) - } - for _, toolCall := range msg.ToolCalls { - if toolCall.OfFunction == nil { - continue - } - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockToolCall, - ToolCallID: toolCall.OfFunction.ID, - ToolName: toolCall.OfFunction.Function.Name, - ToolCallArguments: toolCall.OfFunction.Function.Arguments, - }) - } - return pm -} - -func promptMessageFromChatTool(msg *openai.ChatCompletionToolMessageParam) PromptMessage { - if msg == nil { - return PromptMessage{Role: PromptRoleToolResult} - } - pm := PromptMessage{ - Role: PromptRoleToolResult, - ToolCallID: msg.ToolCallID, - } - if msg.Content.OfString.Value != "" { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: msg.Content.OfString.Value, - }) - } - for _, part := range msg.Content.OfArrayOfContentParts { - pm.Blocks = append(pm.Blocks, PromptBlock{ - Type: PromptBlockText, - Text: part.Text, - }) - } - return pm -} - -func extractChatSystemText(content openai.ChatCompletionSystemMessageParamContentUnion) string { - if content.OfString.Value != "" { - return content.OfString.Value - } - return joinChatText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { - return part.Text - }) -} - -func extractChatDeveloperText(content openai.ChatCompletionDeveloperMessageParamContentUnion) string { - if content.OfString.Value != "" { - return content.OfString.Value - } - return joinChatText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { - return part.Text - }) -} - -func joinChatText[T any](parts []T, extract func(T) string) string { - var values []string - for _, part := range parts { - if text := strings.TrimSpace(extract(part)); text != "" { - values = append(values, text) - } - } - return strings.Join(values, "\n") -} - -func inferPromptMimeTypeFromDataURL(value string) string { - value = strings.TrimSpace(value) - rest, ok := strings.CutPrefix(value, "data:") - if !ok { - return "" - } - value = rest - idx := strings.Index(value, ";") - if idx <= 0 { - return "" - } - return value[:idx] -} - -func BuildDataURL(mimeType, b64Data string) string { - return fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data) -} - -// resolveBlockImageURL returns the image URL for a prompt block, falling back -// to a base64 data URL when no explicit URL is provided. -func resolveBlockImageURL(block PromptBlock) string { - imageURL := strings.TrimSpace(block.ImageURL) - if imageURL == "" && block.ImageB64 != "" { - mimeType := block.MimeType - if mimeType == "" { - mimeType = "image/jpeg" - } - imageURL = BuildDataURL(mimeType, block.ImageB64) - } - return imageURL -} - -// PromptContextToResponsesInput converts the canonical prompt model into Responses input items. -func PromptContextToResponsesInput(ctx PromptContext) responses.ResponseInputParam { - var result responses.ResponseInputParam - - if strings.TrimSpace(ctx.DeveloperPrompt) != "" { - result = append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleDeveloper, - Content: responses.EasyInputMessageContentUnionParam{ - OfString: openai.String(ctx.DeveloperPrompt), - }, - }, - }) - } - - for _, msg := range ctx.Messages { - switch msg.Role { - case PromptRoleUser: - var contentParts responses.ResponseInputMessageContentListParam - hasMultimodal := false - textContent := "" - - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) == "" { - continue - } - if textContent != "" { - textContent += "\n" - } - textContent += block.Text - case PromptBlockImage: - imageURL := resolveBlockImageURL(block) - if imageURL == "" { - continue - } - hasMultimodal = true - contentParts = append(contentParts, responses.ResponseInputContentUnionParam{ - OfInputImage: &responses.ResponseInputImageParam{ - ImageURL: openai.String(imageURL), - Detail: responses.ResponseInputImageDetailAuto, - }, - }) - case PromptBlockFile: - fileData := strings.TrimSpace(block.FileB64) - fileURL := strings.TrimSpace(block.FileURL) - if fileData == "" && fileURL == "" { - continue - } - hasMultimodal = true - fileParam := &responses.ResponseInputFileParam{} - if fileData != "" { - fileParam.FileData = openai.String(fileData) - } - if fileURL != "" { - fileParam.FileURL = openai.String(fileURL) - } - if strings.TrimSpace(block.Filename) != "" { - fileParam.Filename = openai.String(block.Filename) - } - contentParts = append(contentParts, responses.ResponseInputContentUnionParam{ - OfInputFile: fileParam, - }) - case PromptBlockAudio, PromptBlockVideo: - // Unsupported in Responses API; caller should fall back to Chat Completions. - } - } - - if textContent != "" { - textPart := responses.ResponseInputContentUnionParam{ - OfInputText: &responses.ResponseInputTextParam{Text: textContent}, - } - contentParts = append([]responses.ResponseInputContentUnionParam{textPart}, contentParts...) - } - - if hasMultimodal && len(contentParts) > 0 { - result = append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleUser, - Content: responses.EasyInputMessageContentUnionParam{ - OfInputItemContentList: contentParts, - }, - }, - }) - } else if textContent != "" { - result = append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleUser, - Content: responses.EasyInputMessageContentUnionParam{ - OfString: openai.String(textContent), - }, - }, - }) - } - case PromptRoleAssistant: - textParts := make([]string, 0, len(msg.Blocks)) - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) != "" { - textParts = append(textParts, block.Text) - } - case PromptBlockToolCall: - callID := strings.TrimSpace(block.ToolCallID) - name := strings.TrimSpace(block.ToolName) - args := strings.TrimSpace(block.ToolCallArguments) - if callID == "" || name == "" { - continue - } - if args == "" { - args = "{}" - } - result = appendAssistantTextItem(result, textParts) - textParts = textParts[:0] - result = append(result, responses.ResponseInputItemParamOfFunctionCall(args, callID, name)) - } - } - result = appendAssistantTextItem(result, textParts) - case PromptRoleToolResult: - callID := strings.TrimSpace(msg.ToolCallID) - output := strings.TrimSpace(msg.Text()) - if callID == "" || output == "" { - continue - } - result = append(result, responses.ResponseInputItemUnionParam{ - OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ - CallID: callID, - Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ - OfString: openai.String(output), - }, - }, - }) - } - } - - return result -} - -func appendAssistantTextItem(result responses.ResponseInputParam, textParts []string) responses.ResponseInputParam { - text := strings.TrimSpace(strings.Join(textParts, "")) - if text == "" { - return result - } - return append(result, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleAssistant, - Content: responses.EasyInputMessageContentUnionParam{ - OfString: openai.String(text), - }, - }, - }) -} - -// PromptContextToChatCompletionMessages converts the canonical prompt model into Chat Completions messages. -func PromptContextToChatCompletionMessages(ctx PromptContext, supportsVideoURL bool) []openai.ChatCompletionMessageParamUnion { - result := make([]openai.ChatCompletionMessageParamUnion, 0, len(ctx.Messages)+2) - if strings.TrimSpace(ctx.SystemPrompt) != "" { - result = append(result, openai.SystemMessage(ctx.SystemPrompt)) - } - if strings.TrimSpace(ctx.DeveloperPrompt) != "" { - result = append(result, openai.ChatCompletionMessageParamUnion{ - OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ - OfString: openai.String(ctx.DeveloperPrompt), - }, - }, - }) - } - - for _, msg := range ctx.Messages { - switch msg.Role { - case PromptRoleUser: - if promptMessageHasMultimodal(msg) { - result = append(result, openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfArrayOfContentParts: promptBlocksToChatCompletionContentParts(msg.Blocks, supportsVideoURL), - }, - }, - }) - } else { - result = append(result, openai.UserMessage(msg.Text())) - } - case PromptRoleAssistant: - assistant := &openai.ChatCompletionAssistantMessageParam{ - Content: openai.ChatCompletionAssistantMessageParamContentUnion{ - OfString: openai.String(msg.Text()), - }, - } - for _, block := range msg.Blocks { - if block.Type != PromptBlockToolCall { - continue - } - args := strings.TrimSpace(block.ToolCallArguments) - if args == "" { - args = "{}" - } - assistant.ToolCalls = append(assistant.ToolCalls, openai.ChatCompletionMessageToolCallUnionParam{ - OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ - ID: block.ToolCallID, - Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ - Name: block.ToolName, - Arguments: args, - }, - Type: "function", - }, - }) - } - result = append(result, openai.ChatCompletionMessageParamUnion{OfAssistant: assistant}) - case PromptRoleToolResult: - result = append(result, openai.ToolMessage(msg.Text(), msg.ToolCallID)) - } - } - - return result -} - -func promptMessageHasMultimodal(msg PromptMessage) bool { - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockImage, PromptBlockFile, PromptBlockAudio, PromptBlockVideo: - return true - } - } - return false -} - -func promptBlocksToChatCompletionContentParts(blocks []PromptBlock, supportsVideoURL bool) []openai.ChatCompletionContentPartUnionParam { - result := make([]openai.ChatCompletionContentPartUnionParam, 0, len(blocks)) - for _, block := range blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) == "" { - continue - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{Text: block.Text}, - }) - case PromptBlockImage: - imageURL := resolveBlockImageURL(block) - if imageURL == "" { - continue - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: imageURL}, - }, - }) - case PromptBlockFile: - file := openai.ChatCompletionContentPartFileFileParam{} - if strings.TrimSpace(block.FileB64) != "" { - file.FileData = param.NewOpt(block.FileB64) - } - if strings.TrimSpace(block.Filename) != "" { - file.Filename = param.NewOpt(block.Filename) - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfFile: &openai.ChatCompletionContentPartFileParam{File: file}, - }) - case PromptBlockAudio: - if strings.TrimSpace(block.AudioB64) == "" { - continue - } - format := strings.TrimSpace(block.AudioFormat) - if format == "" { - format = "mp3" - } - result = append(result, openai.ChatCompletionContentPartUnionParam{ - OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ - InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ - Data: block.AudioB64, - Format: format, - }, - }, - }) - case PromptBlockVideo: - videoURL := strings.TrimSpace(block.VideoURL) - if videoURL == "" && block.VideoB64 != "" { - mimeType := strings.TrimSpace(block.MimeType) - if mimeType == "" { - mimeType = "video/mp4" - } - videoURL = BuildDataURL(mimeType, block.VideoB64) - } - if videoURL == "" { - continue - } - if supportsVideoURL { - result = append(result, param.Override[openai.ChatCompletionContentPartUnionParam](map[string]any{ - "type": "video_url", - "video_url": map[string]any{ - "url": videoURL, - }, - })) - } - } - } - return result -} - -func HasUnsupportedResponsesPromptContext(ctx PromptContext) bool { - return PromptContextHasBlockType(ctx, PromptBlockAudio, PromptBlockVideo) -} diff --git a/sdk/prompt_projection.go b/sdk/prompt_projection.go deleted file mode 100644 index 7096c887..00000000 --- a/sdk/prompt_projection.go +++ /dev/null @@ -1,284 +0,0 @@ -package sdk - -import ( - "encoding/json" - "fmt" - "strings" -) - -type PromptRole string - -const ( - PromptRoleUser PromptRole = "user" - PromptRoleAssistant PromptRole = "assistant" - PromptRoleToolResult PromptRole = "tool_result" -) - -type PromptBlockType string - -const ( - PromptBlockText PromptBlockType = "text" - PromptBlockImage PromptBlockType = "image" - PromptBlockFile PromptBlockType = "file" - PromptBlockThinking PromptBlockType = "thinking" - PromptBlockToolCall PromptBlockType = "tool_call" - PromptBlockAudio PromptBlockType = "audio" - PromptBlockVideo PromptBlockType = "video" -) - -type PromptBlock struct { - Type PromptBlockType - - Text string - - ImageURL string - ImageB64 string - MimeType string - - FileURL string - FileB64 string - Filename string - - ToolCallID string - ToolName string - ToolCallArguments string - - AudioB64 string - AudioFormat string - - VideoURL string - VideoB64 string -} - -type PromptMessage struct { - Role PromptRole - Blocks []PromptBlock - ToolCallID string - ToolName string - IsError bool -} - -func (m PromptMessage) Text() string { - var texts []string - for _, block := range m.Blocks { - switch block.Type { - case PromptBlockText, PromptBlockThinking: - if strings.TrimSpace(block.Text) != "" { - texts = append(texts, block.Text) - } - } - } - return strings.Join(texts, "\n") -} - -func PromptMessagesFromTurnData(td TurnData) []PromptMessage { - if td.Role == "" { - return nil - } - switch td.Role { - case "user": - msg := PromptMessage{Role: PromptRoleUser} - for _, part := range td.Parts { - switch normalizeTurnPartType(part.Type) { - case "text": - if strings.TrimSpace(part.Text) != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) - } - case "image": - if strings.TrimSpace(part.URL) != "" || promptExtraString(part.Extra, "imageB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockImage, - ImageURL: part.URL, - ImageB64: promptExtraString(part.Extra, "imageB64"), - MimeType: part.MediaType, - }) - } - case "file": - if strings.TrimSpace(part.URL) != "" || strings.TrimSpace(part.Filename) != "" || promptExtraString(part.Extra, "fileB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockFile, - FileURL: part.URL, - FileB64: promptExtraString(part.Extra, "fileB64"), - Filename: part.Filename, - MimeType: part.MediaType, - }) - } - case "audio": - if promptExtraString(part.Extra, "audioB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockAudio, - AudioB64: promptExtraString(part.Extra, "audioB64"), - AudioFormat: promptExtraString(part.Extra, "audioFormat"), - MimeType: part.MediaType, - }) - } - case "video": - if strings.TrimSpace(part.URL) != "" || promptExtraString(part.Extra, "videoB64") != "" { - msg.Blocks = append(msg.Blocks, PromptBlock{ - Type: PromptBlockVideo, - VideoURL: part.URL, - VideoB64: promptExtraString(part.Extra, "videoB64"), - MimeType: part.MediaType, - }) - } - } - } - if len(msg.Blocks) == 0 { - return nil - } - return []PromptMessage{msg} - case "assistant": - assistant := PromptMessage{Role: PromptRoleAssistant} - var results []PromptMessage - for _, part := range td.Parts { - switch normalizeTurnPartType(part.Type) { - case "text": - if strings.TrimSpace(part.Text) != "" { - assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockText, Text: part.Text}) - } - case "reasoning": - text := strings.TrimSpace(part.Reasoning) - if text == "" { - text = strings.TrimSpace(part.Text) - } - if text != "" { - assistant.Blocks = append(assistant.Blocks, PromptBlock{Type: PromptBlockThinking, Text: text}) - } - case "tool": - if strings.TrimSpace(part.ToolCallID) != "" && strings.TrimSpace(part.ToolName) != "" { - assistant.Blocks = append(assistant.Blocks, PromptBlock{ - Type: PromptBlockToolCall, - ToolCallID: part.ToolCallID, - ToolName: part.ToolName, - ToolCallArguments: CanonicalToolArguments(part.Input), - }) - } - outputText := strings.TrimSpace(FormatCanonicalValue(part.Output)) - if outputText == "" { - outputText = strings.TrimSpace(part.ErrorText) - } - if outputText == "" && part.State == "output-denied" { - outputText = "Denied by user" - } - if strings.TrimSpace(part.ToolCallID) != "" && outputText != "" { - results = append(results, PromptMessage{ - Role: PromptRoleToolResult, - ToolCallID: part.ToolCallID, - ToolName: part.ToolName, - IsError: strings.TrimSpace(part.ErrorText) != "", - Blocks: []PromptBlock{{ - Type: PromptBlockText, - Text: outputText, - }}, - }) - } - } - } - if len(assistant.Blocks) == 0 && len(results) == 0 { - return nil - } - out := make([]PromptMessage, 0, 1+len(results)) - if len(assistant.Blocks) > 0 { - out = append(out, assistant) - } - out = append(out, results...) - return out - default: - return nil - } -} - -func TurnDataFromUserPromptMessages(messages []PromptMessage) (TurnData, bool) { - if len(messages) == 0 { - return TurnData{}, false - } - msg := messages[0] - if msg.Role != PromptRoleUser { - return TurnData{}, false - } - td := TurnData{Role: "user"} - td.Parts = make([]TurnPart, 0, len(msg.Blocks)) - for _, block := range msg.Blocks { - switch block.Type { - case PromptBlockText: - if strings.TrimSpace(block.Text) != "" { - td.Parts = append(td.Parts, TurnPart{Type: "text", Text: block.Text}) - } - case PromptBlockImage: - if strings.TrimSpace(block.ImageURL) != "" || strings.TrimSpace(block.ImageB64) != "" { - part := TurnPart{Type: "image", URL: block.ImageURL, MediaType: block.MimeType} - if strings.TrimSpace(block.ImageB64) != "" { - part.Extra = map[string]any{"imageB64": block.ImageB64} - } - td.Parts = append(td.Parts, part) - } - case PromptBlockFile: - if strings.TrimSpace(block.FileURL) != "" || strings.TrimSpace(block.FileB64) != "" || strings.TrimSpace(block.Filename) != "" { - part := TurnPart{ - Type: "file", - URL: block.FileURL, - Filename: block.Filename, - MediaType: block.MimeType, - } - if strings.TrimSpace(block.FileB64) != "" { - part.Extra = map[string]any{"fileB64": block.FileB64} - } - td.Parts = append(td.Parts, part) - } - case PromptBlockAudio: - if strings.TrimSpace(block.AudioB64) != "" { - td.Parts = append(td.Parts, TurnPart{ - Type: "audio", - MediaType: block.MimeType, - Extra: map[string]any{ - "audioB64": block.AudioB64, - "audioFormat": block.AudioFormat, - }, - }) - } - case PromptBlockVideo: - if strings.TrimSpace(block.VideoURL) != "" || strings.TrimSpace(block.VideoB64) != "" { - part := TurnPart{ - Type: "video", - URL: block.VideoURL, - MediaType: block.MimeType, - } - if strings.TrimSpace(block.VideoB64) != "" { - part.Extra = map[string]any{"videoB64": block.VideoB64} - } - td.Parts = append(td.Parts, part) - } - } - } - return td, len(td.Parts) > 0 -} - -func CanonicalToolArguments(raw any) string { - if value := strings.TrimSpace(FormatCanonicalValue(raw)); value != "" { - return value - } - return "{}" -} - -func FormatCanonicalValue(raw any) string { - switch typed := raw.(type) { - case nil: - return "" - case string: - return typed - default: - data, err := json.Marshal(typed) - if err != nil { - return fmt.Sprint(typed) - } - return string(data) - } -} - -func promptExtraString(extra map[string]any, key string) string { - if len(extra) == 0 { - return "" - } - value, _ := extra[key].(string) - return strings.TrimSpace(value) -} diff --git a/sdk/runtime.go b/sdk/runtime.go index 1a2c9447..f433244d 100644 --- a/sdk/runtime.go +++ b/sdk/runtime.go @@ -9,36 +9,77 @@ import ( ) type conversationRuntime interface { - config() *Config - sessionValue() any + agent() *Agent + agentCatalog() AgentCatalog + roomFeatures(conv *Conversation) *RoomFeatures + commands() []Command + turnConfig() *TurnConfig conversationStore() *conversationStateStore approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] providerIdentity() ProviderIdentity } -type staticRuntime struct { - cfg *Config - session any +type staticRuntime[SessionT SessionValue, ConfigDataT ConfigValue] struct { + cfg *Config[SessionT, ConfigDataT] + session SessionT login *bridgev2.UserLogin store *conversationStateStore approval *agentremote.ApprovalFlow[*pendingSDKApprovalData] } -func (r *staticRuntime) config() *Config { return r.cfg } +func (r *staticRuntime[SessionT, ConfigDataT]) agent() *Agent { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.Agent +} + +func (r *staticRuntime[SessionT, ConfigDataT]) agentCatalog() AgentCatalog { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.AgentCatalog +} -func (r *staticRuntime) sessionValue() any { return r.session } +func (r *staticRuntime[SessionT, ConfigDataT]) roomFeatures(conv *Conversation) *RoomFeatures { + if r == nil || r.cfg == nil { + return nil + } + if r.cfg.GetCapabilities != nil { + if rf := r.cfg.GetCapabilities(r.session, conv); rf != nil { + return rf + } + } + return r.cfg.RoomFeatures +} -func (r *staticRuntime) conversationStore() *conversationStateStore { return r.store } +func (r *staticRuntime[SessionT, ConfigDataT]) commands() []Command { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.Commands +} + +func (r *staticRuntime[SessionT, ConfigDataT]) turnConfig() *TurnConfig { + if r == nil || r.cfg == nil { + return nil + } + return r.cfg.TurnManagement +} + +func (r *staticRuntime[SessionT, ConfigDataT]) conversationStore() *conversationStateStore { + return r.store +} -func (r *staticRuntime) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { +func (r *staticRuntime[SessionT, ConfigDataT]) approvalFlowValue() *agentremote.ApprovalFlow[*pendingSDKApprovalData] { return r.approval } -func (r *staticRuntime) providerIdentity() ProviderIdentity { +func (r *staticRuntime[SessionT, ConfigDataT]) providerIdentity() ProviderIdentity { return resolveProviderIdentity(r.cfg) } -func resolveProviderIdentity(cfg *Config) ProviderIdentity { +func resolveProviderIdentity[SessionT SessionValue, ConfigDataT ConfigValue](cfg *Config[SessionT, ConfigDataT]) ProviderIdentity { if cfg == nil { return normalizedProviderIdentity(ProviderIdentity{}) } @@ -65,8 +106,8 @@ type NewConversationOptions struct { // NewConversation creates an SDK conversation wrapper for provider bridges that // want to drive SDK turns without using the default sdkClient implementation. -func NewConversation(ctx context.Context, login *bridgev2.UserLogin, portal *bridgev2.Portal, sender bridgev2.EventSender, cfg *Config, session any, opts ...NewConversationOptions) *Conversation { - rt := &staticRuntime{ +func NewConversation[SessionT SessionValue, ConfigDataT ConfigValue](ctx context.Context, login *bridgev2.UserLogin, portal *bridgev2.Portal, sender bridgev2.EventSender, cfg *Config[SessionT, ConfigDataT], session SessionT, opts ...NewConversationOptions) *Conversation { + rt := &staticRuntime[SessionT, ConfigDataT]{ cfg: cfg, session: session, login: login, diff --git a/sdk/turn.go b/sdk/turn.go index 4a21c9b9..015d45f4 100644 --- a/sdk/turn.go +++ b/sdk/turn.go @@ -933,10 +933,10 @@ func (t *Turn) ensureDefaultFinalEditPayload(finishReason, fallbackBody string) func (t *Turn) resolvedIdleTimeout() time.Duration { const defaultIdleTimeout = time.Minute - if t == nil || t.conv == nil || t.conv.runtime == nil || t.conv.runtime.config() == nil || t.conv.runtime.config().TurnManagement == nil { + if t == nil || t.conv == nil || t.conv.runtime == nil || t.conv.runtime.turnConfig() == nil { return defaultIdleTimeout } - timeoutMs := t.conv.runtime.config().TurnManagement.IdleTimeoutMs + timeoutMs := t.conv.runtime.turnConfig().IdleTimeoutMs switch { case timeoutMs < 0: return 0 diff --git a/sdk/turn_data_test.go b/sdk/turn_data_test.go index 28bd4c3d..98986651 100644 --- a/sdk/turn_data_test.go +++ b/sdk/turn_data_test.go @@ -109,66 +109,3 @@ func TestBuildTurnDataFromUIMessageMergesRuntimeState(t *testing.T) { t.Fatalf("expected source-url part, got %#v", td.Parts) } } - -func TestPromptMessagesFromTurnData(t *testing.T) { - td := TurnData{ - Role: "assistant", - Parts: []TurnPart{ - {Type: "text", Text: "hello"}, - {Type: "reasoning", Reasoning: "thinking"}, - {Type: "tool", ToolCallID: "tool-1", ToolName: "search", Input: map[string]any{"q": "matrix"}, Output: map[string]any{"done": true}}, - }, - } - - messages := PromptMessagesFromTurnData(td) - if len(messages) != 2 { - t.Fatalf("expected assistant + tool result, got %#v", messages) - } - if messages[0].Role != PromptRoleAssistant { - t.Fatalf("unexpected assistant role %#v", messages[0]) - } - if messages[1].Role != PromptRoleToolResult || messages[1].ToolCallID != "tool-1" { - t.Fatalf("unexpected tool result %#v", messages[1]) - } -} - -func TestTurnDataFromUserPromptMessagesPreservesInlineMedia(t *testing.T) { - messages := []PromptMessage{{ - Role: PromptRoleUser, - Blocks: []PromptBlock{ - {Type: PromptBlockText, Text: "describe these attachments"}, - {Type: PromptBlockImage, ImageB64: "aW1hZ2U=", MimeType: "image/png"}, - {Type: PromptBlockFile, FileB64: "data:application/pdf;base64,cGRm", Filename: "doc.pdf", MimeType: "application/pdf"}, - {Type: PromptBlockAudio, AudioB64: "YXVkaW8=", AudioFormat: "mp3", MimeType: "audio/mpeg"}, - {Type: PromptBlockVideo, VideoB64: "dmlkZW8=", MimeType: "video/mp4"}, - }, - }} - - td, ok := TurnDataFromUserPromptMessages(messages) - if !ok { - t.Fatal("expected user prompt messages to produce turn data") - } - if len(td.Parts) != 5 { - t.Fatalf("expected 5 parts, got %#v", td.Parts) - } - - roundTrip := PromptMessagesFromTurnData(td) - if len(roundTrip) != 1 || len(roundTrip[0].Blocks) != 5 { - t.Fatalf("expected one user message with 5 blocks, got %#v", roundTrip) - } - if got := roundTrip[0].Blocks[1].ImageB64; got != "aW1hZ2U=" { - t.Fatalf("expected inline image to round-trip, got %#v", roundTrip[0].Blocks[1]) - } - if got := roundTrip[0].Blocks[2].FileB64; got != "data:application/pdf;base64,cGRm" { - t.Fatalf("expected inline file to round-trip, got %#v", roundTrip[0].Blocks[2]) - } - if got := roundTrip[0].Blocks[3].AudioB64; got != "YXVkaW8=" { - t.Fatalf("expected inline audio to round-trip, got %#v", roundTrip[0].Blocks[3]) - } - if got := roundTrip[0].Blocks[3].AudioFormat; got != "mp3" { - t.Fatalf("expected audio format to round-trip, got %#v", roundTrip[0].Blocks[3]) - } - if got := roundTrip[0].Blocks[4].VideoB64; got != "dmlkZW8=" { - t.Fatalf("expected inline video to round-trip, got %#v", roundTrip[0].Blocks[4]) - } -} diff --git a/sdk/turn_snapshot.go b/sdk/turn_snapshot.go index 005d5d7b..821c02ce 100644 --- a/sdk/turn_snapshot.go +++ b/sdk/turn_snapshot.go @@ -10,7 +10,6 @@ import ( type TurnSnapshot struct { TurnData TurnData UIMessage map[string]any - PromptMessages []PromptMessage Body string ThinkingContent string ToolCalls []agentremote.ToolCallMetadata @@ -25,7 +24,6 @@ func SnapshotFromTurnData(td TurnData, toolType string) TurnSnapshot { return TurnSnapshot{ TurnData: td.Clone(), UIMessage: UIMessageFromTurnData(td), - PromptMessages: PromptMessagesFromTurnData(td), Body: TurnText(td), ThinkingContent: TurnReasoningText(td), ToolCalls: TurnToolCalls(td, toolType), diff --git a/sdk/turn_test.go b/sdk/turn_test.go index 006e7e2c..be24a5da 100644 --- a/sdk/turn_test.go +++ b/sdk/turn_test.go @@ -192,7 +192,7 @@ func TestTurnRequestApprovalWaitsForResolvedDecision(t *testing.T) { UserMXID: "@owner:test", }, } - runtime := &staticRuntime{ + runtime := &staticRuntime[*struct{}, *struct{}]{ login: login, approval: agentremote.NewApprovalFlow(agentremote.ApprovalFlowConfig[*pendingSDKApprovalData]{ Login: func() *bridgev2.UserLogin { return nil }, @@ -248,7 +248,7 @@ func TestTurnRequestApprovalUsesProvidedApprovalID(t *testing.T) { UserMXID: "@owner:test", }, } - runtime := &staticRuntime{ + runtime := &staticRuntime[*struct{}, *struct{}]{ login: login, approval: agentremote.NewApprovalFlow(agentremote.ApprovalFlowConfig[*pendingSDKApprovalData]{ Login: func() *bridgev2.UserLogin { return nil }, @@ -276,7 +276,7 @@ func TestTurnRequestApprovalUsesProvidedApprovalID(t *testing.T) { } func TestTurnStreamSetTransportReceivesEvents(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{}, nil) + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{}, nil) turn := conv.StartTurn(context.Background(), &Agent{ID: "agent"}, nil) var gotTurnID string @@ -685,7 +685,7 @@ func TestTurnBuildFinalEditUsesErrorTextFallback(t *testing.T) { } func TestTurnIdleTimeoutAbortsStuckTurn(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{ + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{ TurnManagement: &TurnConfig{IdleTimeoutMs: 20}, }, nil) turn := conv.StartTurn(context.Background(), nil, nil) @@ -704,7 +704,7 @@ func TestTurnIdleTimeoutAbortsStuckTurn(t *testing.T) { } func TestTurnIdleTimeoutResetsOnActivity(t *testing.T) { - conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config{ + conv := NewConversation(context.Background(), nil, nil, bridgev2.EventSender{}, &Config[*struct{}, *struct{}]{ TurnManagement: &TurnConfig{IdleTimeoutMs: 40}, }, nil) turn := conv.StartTurn(context.Background(), nil, nil) diff --git a/sdk/types.go b/sdk/types.go index ecccca1a..4065bb32 100644 --- a/sdk/types.go +++ b/sdk/types.go @@ -215,8 +215,12 @@ type ProviderIdentity struct { StatusNetwork string } +type SessionValue interface{} + +type ConfigValue interface{} + // Config configures the SDK bridge. -type Config struct { +type Config[SessionT SessionValue, ConfigDataT ConfigValue] struct { // Required Name string Description string @@ -229,29 +233,29 @@ type Config struct { // Message handling (required) // session is the value returned by OnConnect; conv is the conversation; // msg is the incoming message; turn is the pre-created Turn for streaming responses. - OnMessage func(session any, conv *Conversation, msg *Message, turn *Turn) error + OnMessage func(session SessionT, conv *Conversation, msg *Message, turn *Turn) error // Event hooks (optional) - OnConnect func(ctx context.Context, login *LoginInfo) (any, error) // returns session state - OnDisconnect func(session any) - OnReaction func(session any, conv *Conversation, reaction *Reaction) error - OnTyping func(session any, conv *Conversation, typing bool) - OnEdit func(session any, conv *Conversation, edit *MessageEdit) error - OnDelete func(session any, conv *Conversation, msgID string) error - OnRoomName func(session any, conv *Conversation, name string) (bool, error) - OnRoomTopic func(session any, conv *Conversation, topic string) (bool, error) + OnConnect func(ctx context.Context, login *LoginInfo) (SessionT, error) // returns session state + OnDisconnect func(session SessionT) + OnReaction func(session SessionT, conv *Conversation, reaction *Reaction) error + OnTyping func(session SessionT, conv *Conversation, typing bool) + OnEdit func(session SessionT, conv *Conversation, edit *MessageEdit) error + OnDelete func(session SessionT, conv *Conversation, msgID string) error + OnRoomName func(session SessionT, conv *Conversation, name string) (bool, error) + OnRoomTopic func(session SessionT, conv *Conversation, topic string) (bool, error) // Turn management (optional) TurnManagement *TurnConfig // Capabilities (optional, dynamic per-conversation) - GetCapabilities func(session any, conv *Conversation) *RoomFeatures + GetCapabilities func(session SessionT, conv *Conversation) *RoomFeatures // Search & chat ops (optional) - SearchUsers func(ctx context.Context, session any, query string) ([]*bridgev2.ResolveIdentifierResponse, error) - GetContactList func(ctx context.Context, session any) ([]*bridgev2.ResolveIdentifierResponse, error) - ResolveIdentifier func(ctx context.Context, session any, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) - CreateChat func(ctx context.Context, session any, params *CreateChatParams) (*bridgev2.CreateChatResponse, error) + SearchUsers func(ctx context.Context, session SessionT, query string) ([]*bridgev2.ResolveIdentifierResponse, error) + GetContactList func(ctx context.Context, session SessionT) ([]*bridgev2.ResolveIdentifierResponse, error) + ResolveIdentifier func(ctx context.Context, session SessionT, id string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) + CreateChat func(ctx context.Context, session SessionT, params *CreateChatParams) (*bridgev2.CreateChatResponse, error) DeleteChat func(conv *Conversation) error GetChatInfo func(conv *Conversation) (*bridgev2.ChatInfo, error) GetUserInfo func(ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) @@ -296,6 +300,6 @@ type Config struct { ConfigPath string // default: auto-discover DBMeta func() database.MetaTypes // nil = default ExampleConfig string // YAML - ConfigData any // config struct pointer + ConfigData ConfigDataT // config struct pointer ConfigUpgrader configupgrade.Upgrader }