From ad2a2bc7ea44d35fae349c4380b21e3caab81a24 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Tue, 10 Mar 2026 18:33:25 +0100 Subject: [PATCH] Centralize model-change TUI notification in the main loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a change-detection mechanism (lastEmittedModelID + emitModelInfo closure) in RunStream that automatically emits AgentInfo only when the effective model actually changes. This is checked before and after each LLM call, covering per-tool overrides, fallback, model picker, cooldowns, and any future model-switching feature — without each one having to remember to notify the TUI. - loop.go: replace 3 scattered manual AgentInfo emissions with emitModelInfo calls driven by the closure - model_picker.go: remove AgentInfo emission from the tool handler; rename setModelAndEmitInfo to setCurrentAgentModel (no longer emits) - agent_delegation.go: use getEffectiveModelID instead of getAgentModelID so agent-switch events reflect active fallback cooldowns Assisted-By: docker-agent --- pkg/runtime/agent_delegation.go | 4 +- pkg/runtime/loop.go | 42 ++++++++---- pkg/runtime/model_picker.go | 23 +++---- pkg/runtime/model_picker_test.go | 114 +++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 30 deletions(-) create mode 100644 pkg/runtime/model_picker_test.go diff --git a/pkg/runtime/agent_delegation.go b/pkg/runtime/agent_delegation.go index 2e4513f6b..6cedbc680 100644 --- a/pkg/runtime/agent_delegation.go +++ b/pkg/runtime/agent_delegation.go @@ -168,13 +168,13 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses // Restore original agent info in sidebar if originalAgent, err := r.team.Agent(ca); err == nil { - evts <- AgentInfo(originalAgent.Name(), getAgentModelID(originalAgent), originalAgent.Description(), originalAgent.WelcomeMessage()) + evts <- AgentInfo(originalAgent.Name(), r.getEffectiveModelID(originalAgent), 
originalAgent.Description(), originalAgent.WelcomeMessage()) } }() // Emit agent info for the new agent if newAgent, err := r.team.Agent(params.Agent); err == nil { - evts <- AgentInfo(newAgent.Name(), getAgentModelID(newAgent), newAgent.Description(), newAgent.WelcomeMessage()) + evts <- AgentInfo(newAgent.Name(), r.getEffectiveModelID(newAgent), newAgent.Description(), newAgent.WelcomeMessage()) } slog.Debug("Creating new session with parent session", "parent_session_id", sess.ID, "tools_approved", sess.ToolsApproved, "thinking", sess.Thinking) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 638672bb6..9304b69d4 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -91,9 +91,21 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c a := r.resolveSessionAgent(sess) + // lastEmittedModelID tracks what the TUI currently displays. + // emitModelInfo sends an AgentInfo only when the model actually changed, + // so new features (routing, alloy, fallback, model picker, …) never need + // to notify the TUI themselves — the loop handles it. + lastEmittedModelID := r.getEffectiveModelID(a) + emitModelInfo := func(a *agent.Agent, modelID string) { + if modelID == lastEmittedModelID { + return + } + lastEmittedModelID = modelID + events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage()) + } + // Emit agent information for sidebar display - // Use getEffectiveModelID to account for active fallback cooldowns - events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage()) + events <- AgentInfo(a.Name(), lastEmittedModelID, a.Description(), a.WelcomeMessage()) // Emit team information events <- TeamInfo(r.agentDetailsFromTeam(), a.Name()) @@ -241,10 +253,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c modelID := model.ID() - // Notify sidebar when this turn uses a different model (per-tool override). 
- if modelID != defaultModelID { - events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage()) - } + // Notify sidebar when this turn uses a different model + // (per-tool override, model picker, fallback cooldown, …). + emitModelInfo(a, modelID) slog.Debug("Using agent", "agent", a.Name(), "model", modelID) slog.Debug("Getting model definition", "model_id", modelID) @@ -319,17 +330,15 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c return } - // Update sidebar model info to reflect what was actually used this turn. - // Fallback models are sticky (cooldown system persists them), so we only - // emit once. Per-tool model overrides are temporary (one turn), so we - // emit the override and then revert to the agent's default. + // Update sidebar to reflect the model actually used this turn. + // When no fallback kicked in, revert to the agent's default + // (undoes any temporary per-tool override). + actualModelID := defaultModelID if usedModel != nil && usedModel.ID() != model.ID() { slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID()) - events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage()) - } else if model.ID() != defaultModelID { - // Per-tool override was active: revert sidebar to the agent's default model. - events <- AgentInfo(a.Name(), defaultModelID, a.Description(), a.WelcomeMessage()) + actualModelID = usedModel.ID() } + emitModelInfo(a, actualModelID) streamSpan.SetAttributes( attribute.Int("tool.calls", len(res.Calls)), attribute.Int("content.length", len(res.Content)), @@ -350,6 +359,11 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c r.processToolCalls(ctx, sess, res.Calls, agentTools, events) + // Tool handlers (e.g. change_model, revert_model) may have + // switched the effective model. Notify the TUI now so the + // sidebar updates even when the model stops after the tool call. 
+ emitModelInfo(a, r.getEffectiveModelID(a)) + // Record per-toolset model override for the next LLM turn. toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools) diff --git a/pkg/runtime/model_picker.go b/pkg/runtime/model_picker.go index c0e2cf2ae..f711bff1a 100644 --- a/pkg/runtime/model_picker.go +++ b/pkg/runtime/model_picker.go @@ -30,7 +30,7 @@ func (r *LocalRuntime) findModelPickerTool() *builtin.ModelPickerTool { } // handleChangeModel handles the change_model tool call by switching the current agent's model. -func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) { +func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) { var params builtin.ChangeModelArgs if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil { return nil, fmt.Errorf("invalid arguments: %w", err) @@ -53,29 +53,24 @@ func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session )), nil } - return r.setModelAndEmitInfo(ctx, params.Model, events) + return r.setCurrentAgentModel(ctx, params.Model) } // handleRevertModel handles the revert_model tool call by reverting the current agent to its default model. -func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) { - return r.setModelAndEmitInfo(ctx, "", events) +func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) { + return r.setCurrentAgentModel(ctx, "") } -// setModelAndEmitInfo sets the model for the current agent and emits an updated -// AgentInfo event so the UI reflects the change. An empty modelRef reverts to -// the agent's default model. 
-func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, events chan Event) (*tools.ToolCallResult, error) { +// setCurrentAgentModel sets the model for the current agent. An empty modelRef +// reverts to the agent's default model. The main loop detects the resulting +// model change and automatically notifies the TUI, so no AgentInfo event is +// emitted here. +func (r *LocalRuntime) setCurrentAgentModel(ctx context.Context, modelRef string) (*tools.ToolCallResult, error) { currentName := r.CurrentAgentName() if err := r.SetAgentModel(ctx, currentName, modelRef); err != nil { return tools.ResultError(fmt.Sprintf("failed to set model: %v", err)), nil } - if a, err := r.team.Agent(currentName); err == nil { - events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage()) - } else { - slog.Warn("Failed to retrieve agent after model change; UI may not reflect the update", "agent", currentName, "error", err) - } - if modelRef == "" { slog.Info("Model reverted via model_picker tool", "agent", currentName) return tools.ResultSuccess("Model reverted to the agent's default model"), nil diff --git a/pkg/runtime/model_picker_test.go b/pkg/runtime/model_picker_test.go new file mode 100644 index 000000000..0050f9067 --- /dev/null +++ b/pkg/runtime/model_picker_test.go @@ -0,0 +1,114 @@ +package runtime + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" +) + +// staticToolSet is a simple ToolSet that returns a fixed list of tools. 
+type staticToolSet struct { + tools []tools.Tool +} + +func (s *staticToolSet) Tools(context.Context) ([]tools.Tool, error) { + return s.tools, nil +} + +// TestModelChangeEmitsAgentInfo verifies that when a tool call changes the +// agent's model (like change_model does), an AgentInfoEvent with the new +// model ID is emitted even when the model stops in the same turn. +// This is the scenario where the TUI sidebar must be updated. +func TestModelChangeEmitsAgentInfo(t *testing.T) { + newModel := &mockProvider{id: "openai/gpt-4o-mini"} + + // Stream 1: model calls the custom "switch_model" tool and stops. + stream1 := newStreamBuilder(). + AddToolCallName("call_1", "switch_model"). + AddToolCallArguments("call_1", `{}`). + AddStopWithUsage(5, 5). + Build() + + // Stream 2: after the tool result is returned, model says "Done" and stops. + stream2 := newStreamBuilder(). + AddContent("Model switched."). + AddStopWithUsage(5, 5). + Build() + + prov := &queueProvider{id: "test/original-model", streams: []chat.MessageStream{ + stream1, + stream2, + }} + + // Create a toolset that exposes the "switch_model" tool. + switchToolSet := &staticToolSet{tools: []tools.Tool{ + { + Name: "switch_model", + Description: "switch the model", + Annotations: tools.ToolAnnotations{ReadOnlyHint: true}, + }, + }} + + root := agent.New("root", "test agent", + agent.WithModel(prov), + agent.WithToolSets(switchToolSet), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + ) + require.NoError(t, err) + + // Register a custom handler that switches the agent's model override, + // mimicking what handleChangeModel does internally. 
+ rt.toolMap["switch_model"] = func(_ context.Context, _ *session.Session, _ tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) { + a2, _ := rt.team.Agent("root") + a2.SetModelOverride(newModel) + return tools.ResultSuccess("Model changed to openai/gpt-4o-mini"), nil + } + + sess := session.New(session.WithUserMessage("Switch the model"), session.WithToolsApproved(true)) + sess.Title = "Test" + + evCh := rt.RunStream(t.Context(), sess) + var events []Event + for ev := range evCh { + events = append(events, ev) + } + + // Collect all AgentInfoEvents. + var agentInfoEvents []*AgentInfoEvent + for _, ev := range events { + if ai, ok := ev.(*AgentInfoEvent); ok { + agentInfoEvents = append(agentInfoEvents, ai) + } + } + + // There should be at least two AgentInfoEvents: + // 1. The initial one with "test/original-model" + // 2. One after the tool call with "openai/gpt-4o-mini" + require.GreaterOrEqual(t, len(agentInfoEvents), 2, "expected at least 2 AgentInfoEvents, got %d", len(agentInfoEvents)) + + // The first should show the original model. + assert.Equal(t, "test/original-model", agentInfoEvents[0].Model) + + // At least one AgentInfoEvent should show the new model. + foundNewModel := false + for _, ai := range agentInfoEvents { + if ai.Model == "openai/gpt-4o-mini" { + foundNewModel = true + break + } + } + assert.True(t, foundNewModel, "expected an AgentInfoEvent with model 'openai/gpt-4o-mini'") +}