Skip to content
Merged
12 changes: 8 additions & 4 deletions approval_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"

"github.com/beeper/agentremote/turns"
)

// ApprovalReactionHandler is the interface used by BaseReactionHandler to
Expand Down Expand Up @@ -1600,10 +1598,16 @@ func (f *ApprovalFlow[D]) editPromptToResolvedState(
Decision: decision,
ExpiresAt: prompt.ExpiresAt,
})
edit := turns.BuildConvertedEdit(response.Content, response.TopLevelExtra)
if edit == nil {
if response.Content == nil {
return
}
edit := &bridgev2.ConvertedEdit{
ModifiedParts: []*bridgev2.ConvertedEditPart{{
Type: event.EventMessage,
Content: response.Content,
TopLevelExtra: response.TopLevelExtra,
}},
}
ac.login.QueueRemoteEvent(&RemoteEdit{
Portal: ac.portal.PortalKey,
Sender: ac.sender,
Expand Down
3 changes: 3 additions & 0 deletions bridges/ai/agentstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ func (s *AgentStoreAdapter) GetAgentByID(ctx context.Context, agentID string) (*
func (s *AgentStoreAdapter) GetAgentForRoom(ctx context.Context, meta *PortalMetadata) (*agents.AgentDefinition, error) {
agentID := resolveAgentID(meta)
if agentID == "" {
if !s.client.agentsEnabledForLogin() {
return nil, nil
}
agentID = agents.DefaultAgentID // Default to Beep
}

Expand Down
81 changes: 54 additions & 27 deletions bridges/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ const (
ToolNameWebSearch = toolspec.WebSearchName
)

// defaultSimpleModeSystemPrompt is the default system prompt for simple mode rooms.
const defaultSimpleModeSystemPrompt = "You are a helpful assistant."

func hasAssignedAgent(meta *PortalMetadata) bool {
return resolveAgentID(meta) != ""
}
Expand All @@ -50,6 +47,24 @@ func modelRedirectTarget(requested, resolved string) networkid.UserID {
return modelUserID(resolved)
}

func (oc *AIClient) agentsEnabledForLogin() bool {
if oc == nil || oc.UserLogin == nil {
return false
}
return agentsEnabled(loginMetadata(oc.UserLogin))
}

func shouldEnsureDefaultChat(meta *UserLoginMetadata) bool {
if meta == nil {
return false
}
return meta.Agents == nil || *meta.Agents
}

func agentChatsDisabledError() error {
return bridgev2.WrapRespErr(errors.New("agent chats are disabled for this login"), mautrix.MForbidden)
}

// buildAvailableTools returns a list of ToolInfo for all tools based on tool policy.
func (oc *AIClient) buildAvailableTools(meta *PortalMetadata) []ToolInfo {
names := oc.toolNamesForPortal(meta)
Expand Down Expand Up @@ -409,6 +424,9 @@ func (oc *AIClient) CreateChatWithGhost(ctx context.Context, ghost *bridgev2.Gho
return resp.Chat, nil
}
if agentID, ok := parseAgentFromGhostID(ghostID); ok {
if !oc.agentsEnabledForLogin() {
return nil, agentChatsDisabledError()
}
store := NewAgentStoreAdapter(oc)
agent, err := store.GetAgentByID(ctx, agentID)
if err != nil || agent == nil {
Expand All @@ -425,6 +443,9 @@ func (oc *AIClient) CreateChatWithGhost(ctx context.Context, ghost *bridgev2.Gho

// resolveAgentIdentifier resolves an agent to a ghost and optionally creates a chat.
func (oc *AIClient) resolveAgentIdentifier(ctx context.Context, agent *agents.AgentDefinition, modelID string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) {
if !oc.agentsEnabledForLogin() {
return nil, agentChatsDisabledError()
}
explicitModel := modelID != ""
if modelID == "" {
modelID = oc.agentDefaultModel(agent)
Expand Down Expand Up @@ -514,6 +535,9 @@ func modelJoinMember(loginID networkid.UserLoginID, modelID, modelName string, i
}

func (oc *AIClient) createAgentChatWithModel(ctx context.Context, agent *agents.AgentDefinition, modelID string, applyModelOverride bool) (*bridgev2.CreateChatResponse, error) {
if !oc.agentsEnabledForLogin() {
return nil, agentChatsDisabledError()
}
if modelID == "" {
modelID = oc.agentDefaultModel(agent)
}
Expand Down Expand Up @@ -577,7 +601,6 @@ func (oc *AIClient) createNewChat(ctx context.Context, modelID string) (*bridgev
return nil, err
}

// Keep simple mode chats non-agentic by default.
// Rooms created via provisioning (ResolveIdentifier/CreateDM) won't go through our explicit
// post-CreateMatrixRoom call sites. Schedule the welcome notice for when the Matrix room exists.
oc.scheduleWelcomeMessage(ctx, portal.PortalKey)
Expand Down Expand Up @@ -705,7 +728,7 @@ func (oc *AIClient) handleNewChat(
oc.createAndOpenAgentChat(runCtx, portal, agent, modelID, false)
return
}
oc.createAndOpenSimpleChat(runCtx, portal, modelID)
oc.createAndOpenModelChat(runCtx, portal, modelID)
}

func (oc *AIClient) validateNewChatCommand(
Expand All @@ -730,6 +753,9 @@ func (oc *AIClient) resolveNewChatTarget(
if cmd != "agent" {
return nil, "", errors.New(usage)
}
if !oc.agentsEnabledForLogin() {
return nil, "", agentChatsDisabledError()
}
targetID := args[1]
if targetID == "" || len(args) > 2 {
return nil, "", errors.New(usage)
Expand All @@ -753,6 +779,9 @@ func (oc *AIClient) resolveNewChatTarget(
}
agentID := resolveAgentID(meta)
if agentID != "" {
if !oc.agentsEnabledForLogin() {
return nil, "", agentChatsDisabledError()
}
store := NewAgentStoreAdapter(oc)
agent, err := store.GetAgentByID(ctx, agentID)
if err != nil || agent == nil {
Expand Down Expand Up @@ -833,13 +862,15 @@ func (oc *AIClient) createAndOpenAgentChat(ctx context.Context, portal *bridgev2
))
}

func (oc *AIClient) createAndOpenSimpleChat(ctx context.Context, portal *bridgev2.Portal, modelID string) {
newPortal, chatInfo, err := oc.createNewSimpleChat(ctx, modelID)
func (oc *AIClient) createAndOpenModelChat(ctx context.Context, portal *bridgev2.Portal, modelID string) {
chatResp, err := oc.createNewChat(ctx, modelID)
if err != nil {
oc.sendSystemNotice(ctx, portal, "Couldn't create the chat: "+err.Error())
return
}

newPortal := chatResp.Portal
chatInfo := chatResp.PortalInfo
if err := oc.materializePortalRoom(ctx, newPortal, chatInfo, portalRoomMaterializeOptions{SendWelcome: true}); err != nil {
oc.sendSystemNotice(ctx, portal, "Couldn't create the room: "+err.Error())
return
Expand All @@ -852,19 +883,6 @@ func (oc *AIClient) createAndOpenSimpleChat(ctx context.Context, portal *bridgev
))
}

// createNewSimpleChat creates a new simple mode chat portal with the specified model.
func (oc *AIClient) createNewSimpleChat(ctx context.Context, modelID string) (*bridgev2.Portal, *bridgev2.ChatInfo, error) {
portal, chatInfo, err := oc.initPortalForChat(ctx, PortalInitOpts{
ModelID: modelID,
})
if err != nil {
return nil, nil, err
}

// Simple mode rooms are non-agentic. This disables directive processing.
return portal, chatInfo, nil
}

// chatInfoFromPortal builds ChatInfo from an existing portal
func (oc *AIClient) chatInfoFromPortal(ctx context.Context, portal *bridgev2.Portal) *bridgev2.ChatInfo {
meta := portalMeta(portal)
Expand Down Expand Up @@ -981,12 +999,19 @@ func (oc *AIClient) BroadcastRoomState(ctx context.Context, portal *bridgev2.Por
return nil
}

// sendSystemNotice sends an informational notice to the room via the portal pipeline.
// sendSystemNotice sends an informational notice to the room via the bridge bot.
func (oc *AIClient) sendSystemNotice(ctx context.Context, portal *bridgev2.Portal, message string) {
if portal == nil || portal.MXID == "" {
if portal == nil || portal.MXID == "" || oc == nil || oc.UserLogin == nil || oc.UserLogin.Bridge == nil || oc.UserLogin.Bridge.Bot == nil {
return
}
if _, _, err := oc.sendViaPortal(ctx, portal, agentremote.BuildSystemNotice(message), ""); err != nil {
_, err := oc.UserLogin.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{
Parsed: &event.MessageEventContent{
MsgType: event.MsgNotice,
Body: message,
Mentions: &event.Mentions{},
},
}, nil)
if err != nil {
oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to send system notice")
}
}
Expand Down Expand Up @@ -1023,10 +1048,12 @@ func (oc *AIClient) bootstrap(ctx context.Context) {
// Don't return - still create the default chat (matches other bridge patterns)
}

// Create default chat room with Beep agent
if err := oc.ensureDefaultChat(logCtx); err != nil {
oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure default chat")
return
if shouldEnsureDefaultChat(meta) {
// Create default chat room with Beep agent
if err := oc.ensureDefaultChat(logCtx); err != nil {
oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure default chat")
return
}
}

// Mark bootstrap as complete only after successful completion
Expand Down
43 changes: 43 additions & 0 deletions bridges/ai/chat_bootstrap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package ai

import "testing"

func TestShouldEnsureDefaultChat(t *testing.T) {
enabled := true
disabled := false

tests := []struct {
name string
meta *UserLoginMetadata
want bool
}{
{
name: "nil metadata",
meta: nil,
want: false,
},
{
name: "new login with nil agents",
meta: &UserLoginMetadata{},
want: true,
},
{
name: "agents enabled",
meta: &UserLoginMetadata{Agents: &enabled},
want: true,
},
{
name: "agents disabled",
meta: &UserLoginMetadata{Agents: &disabled},
want: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := shouldEnsureDefaultChat(tc.meta); got != tc.want {
t.Fatalf("shouldEnsureDefaultChat() = %v, want %v", got, tc.want)
}
})
}
}
6 changes: 3 additions & 3 deletions bridges/ai/chat_fork_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package ai

import "testing"

func TestCloneForkPortalMetadata_PreservesSimpleMode(t *testing.T) {
func TestCloneForkPortalMetadata_PreservesResolvedModelTarget(t *testing.T) {
src := &PortalMetadata{
ResolvedTarget: &ResolvedTarget{
Kind: ResolvedTargetModel,
Expand All @@ -21,7 +21,7 @@ func TestCloneForkPortalMetadata_PreservesSimpleMode(t *testing.T) {
if got.Title != "Forked Chat" {
t.Fatalf("expected title Forked Chat, got %q", got.Title)
}
if !isSimpleMode(got) {
t.Fatalf("expected forked metadata to keep resolved simple-mode target")
if got.ResolvedTarget == nil || got.ResolvedTarget.Kind != ResolvedTargetModel || got.ResolvedTarget.ModelID != "openai/gpt-5" {
t.Fatalf("expected forked metadata to keep resolved model target, got %#v", got.ResolvedTarget)
}
}
85 changes: 85 additions & 0 deletions bridges/ai/chat_login_redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"slices"
"strings"
"testing"
"time"

"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/database"
)

func TestSearchUsersRequiresLogin(t *testing.T) {
Expand All @@ -29,6 +33,87 @@ func TestGetContactListRequiresLogin(t *testing.T) {
}
}

func TestSearchUsersAndContactsHideAgentsWhenDisabled(t *testing.T) {
enabled := false
oc := &AIClient{
UserLogin: &bridgev2.UserLogin{
UserLogin: &database.UserLogin{
ID: "login-1",
Metadata: &UserLoginMetadata{
Agents: &enabled,
ModelCache: &ModelCache{
Models: []ModelInfo{{
ID: "openai/gpt-5",
Name: "GPT-5",
}},
LastRefresh: time.Now().Unix(),
CacheDuration: 3600,
},
CustomAgents: map[string]*AgentDefinitionContent{
"custom-agent": {
ID: "custom-agent",
Name: "Custom Agent",
Model: "openai/gpt-5",
},
},
},
},
},
connector: &OpenAIConnector{},
}
oc.SetLoggedIn(true)

searchResults, err := oc.SearchUsers(context.Background(), "custom")
if err != nil {
t.Fatalf("SearchUsers returned error: %v", err)
}
if len(searchResults) != 0 {
t.Fatalf("expected agent search results to be hidden, got %#v", searchResults)
}

searchResults, err = oc.SearchUsers(context.Background(), "gpt")
if err != nil {
t.Fatalf("SearchUsers returned error: %v", err)
}
if len(searchResults) != 1 || searchResults[0].UserID != modelUserID("openai/gpt-5") {
t.Fatalf("expected only model search result, got %#v", searchResults)
}

contacts, err := oc.GetContactList(context.Background())
if err != nil {
t.Fatalf("GetContactList returned error: %v", err)
}
if len(contacts) != 1 || contacts[0].UserID != modelUserID("openai/gpt-5") {
t.Fatalf("expected only model contact when agents are disabled, got %#v", contacts)
}
}

func TestCreateChatWithGhostRejectsAgentWhenDisabled(t *testing.T) {
enabled := false
oc := &AIClient{
UserLogin: &bridgev2.UserLogin{
UserLogin: &database.UserLogin{
ID: "login-1",
Metadata: &UserLoginMetadata{
Agents: &enabled,
},
},
},
}

_, err := oc.CreateChatWithGhost(context.Background(), &bridgev2.Ghost{
Ghost: &database.Ghost{
ID: agentUserID("beeper"),
},
})
if err == nil {
t.Fatalf("expected agent ghost chat creation to be rejected")
}
if !strings.Contains(strings.ToLower(err.Error()), "disabled") {
t.Fatalf("expected disabled error, got %v", err)
}
}

func TestModelRedirectTarget(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading
Loading