Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/runtime/agent_delegation.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,13 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses

// Restore original agent info in sidebar
if originalAgent, err := r.team.Agent(ca); err == nil {
evts <- AgentInfo(originalAgent.Name(), getAgentModelID(originalAgent), originalAgent.Description(), originalAgent.WelcomeMessage())
evts <- AgentInfo(originalAgent.Name(), r.getEffectiveModelID(originalAgent), originalAgent.Description(), originalAgent.WelcomeMessage())
}
}()

// Emit agent info for the new agent
if newAgent, err := r.team.Agent(params.Agent); err == nil {
evts <- AgentInfo(newAgent.Name(), getAgentModelID(newAgent), newAgent.Description(), newAgent.WelcomeMessage())
evts <- AgentInfo(newAgent.Name(), r.getEffectiveModelID(newAgent), newAgent.Description(), newAgent.WelcomeMessage())
}

slog.Debug("Creating new session with parent session", "parent_session_id", sess.ID, "tools_approved", sess.ToolsApproved, "thinking", sess.Thinking)
Expand Down
42 changes: 28 additions & 14 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,21 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

a := r.resolveSessionAgent(sess)

// lastEmittedModelID tracks what the TUI currently displays.
// emitModelInfo sends an AgentInfo only when the model actually changed,
// so new features (routing, alloy, fallback, model picker, …) never need
// to notify the TUI themselves — the loop handles it.
lastEmittedModelID := r.getEffectiveModelID(a)
emitModelInfo := func(a *agent.Agent, modelID string) {
if modelID == lastEmittedModelID {
return
}
lastEmittedModelID = modelID
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())
}

// Emit agent information for sidebar display
// Use getEffectiveModelID to account for active fallback cooldowns
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())
events <- AgentInfo(a.Name(), lastEmittedModelID, a.Description(), a.WelcomeMessage())

// Emit team information
events <- TeamInfo(r.agentDetailsFromTeam(), a.Name())
Expand Down Expand Up @@ -241,10 +253,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

modelID := model.ID()

// Notify sidebar when this turn uses a different model (per-tool override).
if modelID != defaultModelID {
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())
}
// Notify sidebar when this turn uses a different model
// (per-tool override, model picker, fallback cooldown, …).
emitModelInfo(a, modelID)

slog.Debug("Using agent", "agent", a.Name(), "model", modelID)
slog.Debug("Getting model definition", "model_id", modelID)
Expand Down Expand Up @@ -319,17 +330,15 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
return
}

// Update sidebar model info to reflect what was actually used this turn.
// Fallback models are sticky (cooldown system persists them), so we only
// emit once. Per-tool model overrides are temporary (one turn), so we
// emit the override and then revert to the agent's default.
// Update sidebar to reflect the model actually used this turn.
// When no fallback kicked in, revert to the agent's default
// (undoes any temporary per-tool override).
actualModelID := defaultModelID
if usedModel != nil && usedModel.ID() != model.ID() {
slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID())
events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())
} else if model.ID() != defaultModelID {
// Per-tool override was active: revert sidebar to the agent's default model.
events <- AgentInfo(a.Name(), defaultModelID, a.Description(), a.WelcomeMessage())
actualModelID = usedModel.ID()
}
emitModelInfo(a, actualModelID)
streamSpan.SetAttributes(
attribute.Int("tool.calls", len(res.Calls)),
attribute.Int("content.length", len(res.Content)),
Expand All @@ -350,6 +359,11 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

r.processToolCalls(ctx, sess, res.Calls, agentTools, events)

// Tool handlers (e.g. change_model, revert_model) may have
// switched the effective model. Notify the TUI now so the
// sidebar updates even when the model stops after the tool call.
emitModelInfo(a, r.getEffectiveModelID(a))

// Record per-toolset model override for the next LLM turn.
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)

Expand Down
23 changes: 9 additions & 14 deletions pkg/runtime/model_picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (r *LocalRuntime) findModelPickerTool() *builtin.ModelPickerTool {
}

// handleChangeModel handles the change_model tool call by switching the current agent's model.
func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) {
var params builtin.ChangeModelArgs
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil {
return nil, fmt.Errorf("invalid arguments: %w", err)
Expand All @@ -53,29 +53,24 @@ func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session
)), nil
}

return r.setModelAndEmitInfo(ctx, params.Model, events)
return r.setCurrentAgentModel(ctx, params.Model)
}

// handleRevertModel handles the revert_model tool call by reverting the current agent to its default model.
func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
return r.setModelAndEmitInfo(ctx, "", events)
func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) {
return r.setCurrentAgentModel(ctx, "")
}

// setModelAndEmitInfo sets the model for the current agent and emits an updated
// AgentInfo event so the UI reflects the change. An empty modelRef reverts to
// the agent's default model.
func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, events chan Event) (*tools.ToolCallResult, error) {
// setCurrentAgentModel sets the model for the current agent. An empty modelRef
// reverts to the agent's default model. The main loop detects the resulting
// model change and automatically notifies the TUI, so no AgentInfo event is
// emitted here.
func (r *LocalRuntime) setCurrentAgentModel(ctx context.Context, modelRef string) (*tools.ToolCallResult, error) {
currentName := r.CurrentAgentName()
if err := r.SetAgentModel(ctx, currentName, modelRef); err != nil {
return tools.ResultError(fmt.Sprintf("failed to set model: %v", err)), nil
}

if a, err := r.team.Agent(currentName); err == nil {
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())
} else {
slog.Warn("Failed to retrieve agent after model change; UI may not reflect the update", "agent", currentName, "error", err)
}

if modelRef == "" {
slog.Info("Model reverted via model_picker tool", "agent", currentName)
return tools.ResultSuccess("Model reverted to the agent's default model"), nil
Expand Down
114 changes: 114 additions & 0 deletions pkg/runtime/model_picker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package runtime

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/agent"
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/team"
"github.com/docker/docker-agent/pkg/tools"
)

// staticToolSet is a simple ToolSet that returns a fixed list of tools.
type staticToolSet struct {
tools []tools.Tool
}

func (s *staticToolSet) Tools(context.Context) ([]tools.Tool, error) {
return s.tools, nil
}

// TestModelChangeEmitsAgentInfo verifies that when a tool call changes the
// agent's model (like change_model does), an AgentInfoEvent with the new
// model ID is emitted even when the model stops in the same turn.
// This is the scenario where the TUI sidebar must be updated.
func TestModelChangeEmitsAgentInfo(t *testing.T) {
newModel := &mockProvider{id: "openai/gpt-4o-mini"}

// Stream 1: model calls the custom "switch_model" tool and stops.
stream1 := newStreamBuilder().
AddToolCallName("call_1", "switch_model").
AddToolCallArguments("call_1", `{}`).
AddStopWithUsage(5, 5).
Build()

// Stream 2: after the tool result is returned, model says "Done" and stops.
stream2 := newStreamBuilder().
AddContent("Model switched.").
AddStopWithUsage(5, 5).
Build()

prov := &queueProvider{id: "test/original-model", streams: []chat.MessageStream{
stream1,
stream2,
}}

// Create a toolset that exposes the "switch_model" tool.
switchToolSet := &staticToolSet{tools: []tools.Tool{
{
Name: "switch_model",
Description: "switch the model",
Annotations: tools.ToolAnnotations{ReadOnlyHint: true},
},
}}

root := agent.New("root", "test agent",
agent.WithModel(prov),
agent.WithToolSets(switchToolSet),
)
tm := team.New(team.WithAgents(root))

rt, err := NewLocalRuntime(tm,
WithSessionCompaction(false),
WithModelStore(mockModelStore{}),
)
require.NoError(t, err)

// Register a custom handler that switches the agent's model override,
// mimicking what handleChangeModel does internally.
rt.toolMap["switch_model"] = func(_ context.Context, _ *session.Session, _ tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) {
a2, _ := rt.team.Agent("root")
a2.SetModelOverride(newModel)
return tools.ResultSuccess("Model changed to openai/gpt-4o-mini"), nil
}

sess := session.New(session.WithUserMessage("Switch the model"), session.WithToolsApproved(true))
sess.Title = "Test"

evCh := rt.RunStream(t.Context(), sess)
var events []Event
for ev := range evCh {
events = append(events, ev)
}

// Collect all AgentInfoEvents.
var agentInfoEvents []*AgentInfoEvent
for _, ev := range events {
if ai, ok := ev.(*AgentInfoEvent); ok {
agentInfoEvents = append(agentInfoEvents, ai)
}
}

// There should be at least two AgentInfoEvents:
// 1. The initial one with "test/original-model"
// 2. One after the tool call with "openai/gpt-4o-mini"
require.GreaterOrEqual(t, len(agentInfoEvents), 2, "expected at least 2 AgentInfoEvents, got %d", len(agentInfoEvents))

// The first should show the original model.
assert.Equal(t, "test/original-model", agentInfoEvents[0].Model)

// At least one AgentInfoEvent should show the new model.
foundNewModel := false
for _, ai := range agentInfoEvents {
if ai.Model == "openai/gpt-4o-mini" {
foundNewModel = true
break
}
}
assert.True(t, foundNewModel, "expected an AgentInfoEvent with model 'openai/gpt-4o-mini'")
}
Loading