From 693081745b83a6f20e84d9ada8a5250ed9a3d75c Mon Sep 17 00:00:00 2001 From: David Gageot Date: Tue, 10 Mar 2026 18:37:25 +0100 Subject: [PATCH] refactor: remove duplication in model resolution, thinking budget, and message construction - NewAgentMessage takes a plain string instead of *agent.Agent since only the name was used. - Replace duplicated isThinkingBudgetDisabled functions in provider.go and teamloader.go with a single ThinkingBudget.IsDisabled() method. - Extract resolveModelRefs to deduplicate createProvidersFromInlineAlloy and createProvidersFromAlloyConfig, then inline the trivial wrappers. - Extract ParseModelRef for the strings.Cut(ref, "/") + ModelConfig construction pattern repeated in six call sites. Assisted-By: docker-agent --- pkg/config/latest/model_ref.go | 20 ++++++ pkg/config/latest/types.go | 28 +++++++-- pkg/config/overrides.go | 9 +-- pkg/evaluation/eval.go | 11 +--- pkg/model/provider/override_test.go | 2 +- pkg/model/provider/provider.go | 31 ++------- pkg/runtime/model_switcher.go | 97 +++++++---------------------- pkg/runtime/tool_dispatch.go | 2 +- pkg/session/session.go | 4 +- pkg/session/session_test.go | 34 +++++----- pkg/session/store_test.go | 26 ++++---- pkg/teamloader/teamloader.go | 25 ++------ pkg/teamloader/teamloader_test.go | 2 +- 13 files changed, 115 insertions(+), 176 deletions(-) create mode 100644 pkg/config/latest/model_ref.go diff --git a/pkg/config/latest/model_ref.go b/pkg/config/latest/model_ref.go new file mode 100644 index 000000000..2ab6c3c4c --- /dev/null +++ b/pkg/config/latest/model_ref.go @@ -0,0 +1,20 @@ +package latest + +import ( + "fmt" + "strings" +) + +// ParseModelRef parses an inline "provider/model" reference into a +// ModelConfig. It returns an error when the string does not contain +// exactly one "/" separator or when either part is empty. +// +// cfg, err := ParseModelRef("openai/gpt-4o") +// // cfg.Provider == "openai", cfg.Model == "gpt-4o" +func ParseModelRef(ref string) (ModelConfig, error) { + providerName, model, ok := strings.Cut(ref, "/") + if !ok || providerName == "" || model == "" { + return ModelConfig{}, fmt.Errorf("invalid model reference %q: expected 'provider/model' format", ref) + } + return ModelConfig{Provider: providerName, Model: model}, nil +} diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 54cd3e6b5..7caca6695 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -439,12 +439,12 @@ func (f *FlexibleModelConfig) UnmarshalYAML(unmarshal func(any) error) error { // Try string shorthand first var shorthand string if err := unmarshal(&shorthand); err == nil && shorthand != "" { - provider, model, ok := strings.Cut(shorthand, "/") - if !ok || provider == "" || model == "" { + parsed, parseErr := ParseModelRef(shorthand) + if parseErr != nil { return fmt.Errorf("invalid model shorthand %q: expected format 'provider/model'", shorthand) } - f.Provider = provider - f.Model = model + f.Provider = parsed.Provider + f.Model = parsed.Model return nil } @@ -707,6 +707,26 @@ func (t ThinkingBudget) MarshalYAML() (any, error) { return t.Tokens, nil } +// IsDisabled returns true if the thinking budget is explicitly disabled. +// A nil receiver is treated as "not configured" (not disabled). +// +// Disabled when: +// - Tokens == 0 with no Effort (thinking_budget: 0) +// - Effort == "none" (thinking_budget: none) +// +// NOT disabled when: +// - Tokens > 0 or Tokens == -1 (explicit token budget) +// - Effort is a real level like "medium" or "high" +func (t *ThinkingBudget) IsDisabled() bool { + if t == nil { + return false + } + if t.Tokens == 0 && t.Effort == "" { + return true + } + return t.Effort == "none" +} + // MarshalJSON implements custom marshaling to output simple string or int format // This ensures JSON and YAML have the same flattened format for consistency func (t ThinkingBudget) MarshalJSON() ([]byte, error) { diff --git a/pkg/config/overrides.go b/pkg/config/overrides.go index 7eb1c5785..3bb0afad6 100644 --- a/pkg/config/overrides.go +++ b/pkg/config/overrides.go @@ -186,15 +186,12 @@ func ensureSingleModelExists(cfg *latest.Config, modelName, context string) erro return nil } - providerName, model, ok := strings.Cut(modelName, "/") - if !ok || providerName == "" || model == "" { + parsed, err := latest.ParseModelRef(modelName) + if err != nil { return fmt.Errorf("%s references non-existent model '%s'", context, modelName) } - cfg.Models[modelName] = latest.ModelConfig{ - Provider: providerName, - Model: model, - } + cfg.Models[modelName] = parsed return nil } diff --git a/pkg/evaluation/eval.go b/pkg/evaluation/eval.go index cfcf2a498..509da5489 100644 --- a/pkg/evaluation/eval.go +++ b/pkg/evaluation/eval.go @@ -591,22 +591,17 @@ func createJudgeModel(ctx context.Context, judgeModel string, runConfig *config. return nil, nil } - providerName, model, ok := strings.Cut(judgeModel, "/") - if !ok { + cfg, err := latest.ParseModelRef(judgeModel) + if err != nil { return nil, fmt.Errorf("invalid judge model format %q: expected 'provider/model'", judgeModel) } - cfg := &latest.ModelConfig{ - Provider: providerName, - Model: model, - } - var opts []options.Opt if runConfig.ModelsGateway != "" { opts = append(opts, options.WithGateway(runConfig.ModelsGateway)) } - judge, err := provider.New(ctx, cfg, runConfig.EnvProvider(), opts...) + judge, err := provider.New(ctx, &cfg, runConfig.EnvProvider(), opts...) if err != nil { return nil, fmt.Errorf("creating judge model: %w", err) } diff --git a/pkg/model/provider/override_test.go b/pkg/model/provider/override_test.go index 9232fc095..3da941c7d 100644 --- a/pkg/model/provider/override_test.go +++ b/pkg/model/provider/override_test.go @@ -501,7 +501,7 @@ func TestIsThinkingBudgetDisabled(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tt.expected, isThinkingBudgetDisabled(tt.budget)) + assert.Equal(t, tt.expected, tt.budget.IsDisabled()) }) } } diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index 419897e50..550f6a3eb 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -205,16 +205,11 @@ func createRuleBasedRouter(ctx context.Context, cfg *latest.ModelConfig, models } // Otherwise, treat as an inline model spec (e.g., "openai/gpt-4o") - providerName, model, ok := strings.Cut(modelSpec, "/") - if !ok { + inlineCfg, parseErr := latest.ParseModelRef(modelSpec) + if parseErr != nil { return nil, fmt.Errorf("invalid model spec %q: expected 'provider/model' format or a model reference", modelSpec) } - - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: model, - } - p, err := createDirectProvider(ctx, inlineCfg, env, factoryOpts...) + p, err := createDirectProvider(ctx, &inlineCfg, env, factoryOpts...) if err != nil { return nil, err } @@ -394,7 +389,7 @@ func applyOverrides(cfg *latest.ModelConfig, opts *options.ModelOptions) *latest // 1. ThinkingBudget is nil (not configured) - apply defaults to enable thinking // 2. ThinkingBudget is explicitly disabled (Tokens == 0 or Effort == "none") - clear and re-apply defaults // This allows /think to enable thinking with provider defaults even when config had thinking_budget: 0 - if enhancedCfg.ThinkingBudget == nil || isThinkingBudgetDisabled(enhancedCfg.ThinkingBudget) { + if enhancedCfg.ThinkingBudget == nil || enhancedCfg.ThinkingBudget.IsDisabled() { enhancedCfg.ThinkingBudget = nil applyModelDefaults(&enhancedCfg) slog.Debug("Override: thinking enabled - applied default thinking configuration", @@ -407,22 +402,6 @@ func applyOverrides(cfg *latest.ModelConfig, opts *options.ModelOptions) *latest return &enhancedCfg } -// isThinkingBudgetDisabled returns true if the thinking budget is explicitly disabled. -// NOT disabled when: -// - Tokens > 0 or Tokens == -1 (explicit token budget) -// - Effort is set to something other than "none" (e.g., "medium", "high") -func isThinkingBudgetDisabled(tb *latest.ThinkingBudget) bool { - if tb == nil { - return false - } - if tb.Effort == "none" { - return true - } - // Tokens == 0 with no Effort means explicitly disabled (thinking_budget: 0) - // Tokens == 0 with Effort set (e.g., "medium") means Effort-based config, not disabled - return tb.Tokens == 0 && tb.Effort == "" -} - // applyModelDefaults applies provider-specific default values for model configuration. // These defaults are applied only if the user hasn't explicitly set the values. // @@ -441,7 +420,7 @@ func applyModelDefaults(cfg *latest.ModelConfig) { // If thinking is explicitly disabled (thinking_budget: 0 or thinking_budget: none), // set ThinkingBudget to nil to completely disable thinking. // This ensures no thinking config is sent to the provider. - if isThinkingBudgetDisabled(cfg.ThinkingBudget) { + if cfg.ThinkingBudget.IsDisabled() { cfg.ThinkingBudget = nil slog.Debug("Thinking explicitly disabled via thinking_budget: 0 or none", "provider", cfg.Provider, diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index cd6818712..5aa97fc6e 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -88,7 +88,7 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st modelConfig.Name = modelRef // Check if this is an alloy model (no provider, comma-separated models) if isAlloyModelConfig(modelConfig) { - providers, err := r.createProvidersFromAlloyConfig(ctx, modelConfig) + providers, err := r.resolveModelRefs(ctx, modelConfig.Model) if err != nil { return fmt.Errorf("failed to create alloy model from config: %w", err) } @@ -109,7 +109,7 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st // Check if this is an inline alloy spec (comma-separated provider/model specs) // e.g., "openai/gpt-4o,anthropic/claude-sonnet-4-0" if isInlineAlloySpec(modelRef) { - providers, err := r.createProvidersFromInlineAlloy(ctx, modelRef) + providers, err := r.resolveModelRefs(ctx, modelRef) if err != nil { return fmt.Errorf("failed to create inline alloy model: %w", err) } @@ -146,17 +146,12 @@ func (r *LocalRuntime) resolveModelRef(ctx context.Context, modelRef string) (pr } // Try inline "provider/model" format. - providerName, modelName, ok := strings.Cut(modelRef, "/") - if !ok || providerName == "" || modelName == "" { + inlineCfg, err := latest.ParseModelRef(modelRef) + if err != nil { return nil, fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef) } - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } - - return r.createProviderFromConfig(ctx, inlineCfg) + return r.createProviderFromConfig(ctx, &inlineCfg) } // isAlloyModelConfig checks if a model config is an alloy model (multiple models). @@ -186,92 +181,44 @@ func isInlineAlloySpec(modelRef string) bool { return validParts >= 2 } -// createProvidersFromInlineAlloy creates providers from an inline alloy spec. -// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". -func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, modelRef string) ([]provider.Provider, error) { +// resolveModelRefs resolves a comma-separated list of model references into +// providers. Each reference is first looked up in the config by name; if not +// found it is parsed as an inline "provider/model" spec. +func (r *LocalRuntime) resolveModelRefs(ctx context.Context, commaSeparatedRefs string) ([]provider.Provider, error) { var providers []provider.Provider - for part := range strings.SplitSeq(modelRef, ",") { - part = strings.TrimSpace(part) - if part == "" { + for ref := range strings.SplitSeq(commaSeparatedRefs, ",") { + ref = strings.TrimSpace(ref) + if ref == "" { continue } - // Check if this part exists as a named model in config - if modelCfg, exists := r.modelSwitcherCfg.Models[part]; exists { - modelCfg.Name = part + // Check if this ref exists as a named model in config + if modelCfg, exists := r.modelSwitcherCfg.Models[ref]; exists { + modelCfg.Name = ref prov, err := r.createProviderFromConfig(ctx, &modelCfg) if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", part, err) + return nil, fmt.Errorf("failed to create provider for %q: %w", ref, err) } providers = append(providers, prov) continue } // Parse as provider/model - providerName, modelName, ok := strings.Cut(part, "/") - if !ok { - return nil, fmt.Errorf("invalid model reference %q in inline alloy: expected 'provider/model' format", part) - } - - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } - prov, err := r.createProviderFromConfig(ctx, inlineCfg) - if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", part, err) - } - providers = append(providers, prov) - } - - if len(providers) == 0 { - return nil, errors.New("inline alloy spec has no valid models") - } - - return providers, nil -} - -// createProvidersFromAlloyConfig creates providers for each model in an alloy configuration. -func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloyCfg latest.ModelConfig) ([]provider.Provider, error) { - var providers []provider.Provider - - for modelRef := range strings.SplitSeq(alloyCfg.Model, ",") { - modelRef = strings.TrimSpace(modelRef) - if modelRef == "" { - continue + inlineCfg, parseErr := latest.ParseModelRef(ref) + if parseErr != nil { + return nil, fmt.Errorf("invalid model reference %q: expected 'provider/model' format or a named model from config", ref) } - // Check if this model reference exists in the config - if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists { - modelCfg.Name = modelRef - prov, err := r.createProviderFromConfig(ctx, &modelCfg) - if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err) - } - providers = append(providers, prov) - continue - } - - // Try parsing as inline spec (provider/model) - providerName, modelName, ok := strings.Cut(modelRef, "/") - if !ok { - return nil, fmt.Errorf("invalid model reference %q in alloy config: expected 'provider/model' format", modelRef) - } - - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } - prov, err := r.createProviderFromConfig(ctx, inlineCfg) + prov, err := r.createProviderFromConfig(ctx, &inlineCfg) if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err) + return nil, fmt.Errorf("failed to create provider for %q: %w", ref, err) } providers = append(providers, prov) } if len(providers) == 0 { - return nil, errors.New("alloy model config has no valid models") + return nil, errors.New("no valid models found in model reference list") } return providers, nil diff --git a/pkg/runtime/tool_dispatch.go b/pkg/runtime/tool_dispatch.go index 088c4a850..943b2d628 100644 --- a/pkg/runtime/tool_dispatch.go +++ b/pkg/runtime/tool_dispatch.go @@ -424,7 +424,7 @@ func (r *LocalRuntime) runAgentTool(ctx context.Context, handler ToolHandlerFunc } func addAgentMessage(sess *session.Session, a *agent.Agent, msg *chat.Message, events chan Event) { - agentMsg := session.NewAgentMessage(a, msg) + agentMsg := session.NewAgentMessage(a.Name(), msg) sess.AddMessage(agentMsg) events <- MessageAdded(sess.ID, agentMsg, a.Name()) } diff --git a/pkg/session/session.go b/pkg/session/session.go index 20b8730f6..9e6782b5a 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -190,9 +190,9 @@ func UserMessage(content string, multiContent ...chat.MessagePart) *Message { } } -func NewAgentMessage(a *agent.Agent, message *chat.Message) *Message { +func NewAgentMessage(agentName string, message *chat.Message) *Message { return &Message{ - AgentName: a.Name(), + AgentName: agentName, Message: *message, } } diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go index 7ec4ee481..2b802edf0 100644 --- a/pkg/session/session_test.go +++ b/pkg/session/session_test.go @@ -77,12 +77,12 @@ func TestGetMessagesWithToolCalls(t *testing.T) { s := New() - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "test message", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleAssistant, Content: "using tool", ToolCalls: []tools.ToolCall{ @@ -92,7 +92,7 @@ func TestGetMessagesWithToolCalls(t *testing.T) { }, })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleTool, ToolCallID: "test-tool", Content: "tool result", @@ -118,22 +118,22 @@ func TestGetMessagesWithSummary(t *testing.T) { s := New() - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "first message", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleAssistant, Content: "first response", })) s.Messages = append(s.Messages, Item{Summary: "This is a summary of the conversation so far"}) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "message after summary", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleAssistant, Content: "response after summary", })) @@ -221,8 +221,6 @@ func TestGetMessages_CacheControlWithSummary(t *testing.T) { func TestGetLastUserMessages(t *testing.T) { t.Parallel() - testAgent := &agent.Agent{} - t.Run("empty session returns empty slice", func(t *testing.T) { t.Parallel() s := New() @@ -232,7 +230,7 @@ func TestGetLastUserMessages(t *testing.T) { t.Run("session with fewer messages than requested returns all", func(t *testing.T) { t.Parallel() s := New() - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "Only message", })) @@ -244,23 +242,23 @@ func TestGetLastUserMessages(t *testing.T) { t.Run("session returns last n user messages in order", func(t *testing.T) { t.Parallel() s := New() - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "First", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Response 1", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "Second", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Response 2", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "Third", })) @@ -274,15 +272,15 @@ func TestGetLastUserMessages(t *testing.T) { t.Run("skips empty user messages", func(t *testing.T) { t.Parallel() s := New() - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "First", })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: " ", // Empty after trim })) - s.AddMessage(NewAgentMessage(testAgent, &chat.Message{ + s.AddMessage(NewAgentMessage("", &chat.Message{ Role: chat.MessageRoleUser, Content: "Third", })) diff --git a/pkg/session/store_test.go b/pkg/session/store_test.go index c99baf1bd..ba0719e97 100644 --- a/pkg/session/store_test.go +++ b/pkg/session/store_test.go @@ -29,11 +29,11 @@ func TestStoreAgentName(t *testing.T) { ID: "test-session", Messages: []Item{ NewMessageItem(UserMessage("Hello")), - NewMessageItem(NewAgentMessage(testAgent1, &chat.Message{ + NewMessageItem(NewAgentMessage(testAgent1.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Hello from test-agent-1", })), - NewMessageItem(NewAgentMessage(testAgent2, &chat.Message{ + NewMessageItem(NewAgentMessage(testAgent2.Name(), &chat.Message{ Role: chat.MessageRoleUser, Content: "Another message from test-agent-2", })), @@ -82,11 +82,11 @@ func TestStoreMultipleAgents(t *testing.T) { CreatedAt: time.Now(), Messages: []Item{ NewMessageItem(UserMessage("Start conversation")), - NewMessageItem(NewAgentMessage(agent1, &chat.Message{ + NewMessageItem(NewAgentMessage(agent1.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Response from agent 1", })), - NewMessageItem(NewAgentMessage(agent2, &chat.Message{ + NewMessageItem(NewAgentMessage(agent2.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Response from agent 2", })), @@ -128,7 +128,7 @@ func TestGetSessions(t *testing.T) { session1 := &Session{ ID: "session-1", Messages: []Item{ - NewMessageItem(NewAgentMessage(testAgent, &chat.Message{ + NewMessageItem(NewAgentMessage(testAgent.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Message from session 1", })), @@ -139,7 +139,7 @@ func TestGetSessions(t *testing.T) { session2 := &Session{ ID: "session-2", Messages: []Item{ - NewMessageItem(NewAgentMessage(testAgent, &chat.Message{ + NewMessageItem(NewAgentMessage(testAgent.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Message from session 2", })), @@ -180,7 +180,7 @@ func TestGetSessionSummaries(t *testing.T) { ID: "session-1", Title: "First Session", Messages: []Item{ - NewMessageItem(NewAgentMessage(testAgent, &chat.Message{ + NewMessageItem(NewAgentMessage(testAgent.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "A very long message that should not be loaded when getting summaries", })), @@ -192,7 +192,7 @@ func TestGetSessionSummaries(t *testing.T) { ID: "session-2", Title: "Second Session", Messages: []Item{ - NewMessageItem(NewAgentMessage(testAgent, &chat.Message{ + NewMessageItem(NewAgentMessage(testAgent.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Another long message that should not be loaded when getting summaries", })), @@ -236,7 +236,7 @@ func TestBranchSessionCopiesPrefix(t *testing.T) { CreatedAt: time.Now(), Messages: []Item{ NewMessageItem(UserMessage("Hello")), - NewMessageItem(NewAgentMessage(testAgent, &chat.Message{ + NewMessageItem(NewAgentMessage(testAgent.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Response", })), @@ -330,11 +330,11 @@ func TestStoreAgentNameJSON(t *testing.T) { ID: "json-test-session", Messages: []Item{ NewMessageItem(UserMessage("User input")), - NewMessageItem(NewAgentMessage(agent1, &chat.Message{ + NewMessageItem(NewAgentMessage(agent1.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Response from my-agent", })), - NewMessageItem(NewAgentMessage(agent2, &chat.Message{ + NewMessageItem(NewAgentMessage(agent2.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Response from another-agent", })), @@ -401,7 +401,7 @@ func TestUpdateSession_LazyCreation(t *testing.T) { _, err = store.AddMessage(t.Context(), "lazy-session", UserMessage("Hello")) require.NoError(t, err) - _, err = store.AddMessage(t.Context(), "lazy-session", NewAgentMessage(testAgent, &chat.Message{ + _, err = store.AddMessage(t.Context(), "lazy-session", NewAgentMessage(testAgent.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Hi there!", })) @@ -443,7 +443,7 @@ func TestUpdateSession_LazyCreation_InMemory(t *testing.T) { // Add messages via AddMessage _, err = store.AddMessage(t.Context(), "lazy-session", UserMessage("Hello")) require.NoError(t, err) - _, err = store.AddMessage(t.Context(), "lazy-session", NewAgentMessage(testAgent, &chat.Message{ + _, err = store.AddMessage(t.Context(), "lazy-session", NewAgentMessage(testAgent.Name(), &chat.Message{ Role: chat.MessageRoleAssistant, Content: "Hi there!", })) diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index adf6bfffd..d186aee81 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -33,20 +33,6 @@ import ( var defaultMaxTokens int64 = 32000 -// isThinkingBudgetDisabled returns true if the thinking budget is explicitly set to disable thinking -// (e.g., thinking_budget: 0 or thinking_budget: none). -func isThinkingBudgetDisabled(tb *latest.ThinkingBudget) bool { - if tb == nil { - return false - } - // Disabled if tokens is explicitly 0 - if tb.Tokens == 0 && tb.Effort == "" { - return true - } - // Disabled if effort is "none" - return tb.Effort == "none" -} - type loadOptions struct { modelOverrides []string promptFiles []string @@ -316,7 +302,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC // Check if thinking_budget was explicitly configured BEFORE provider defaults are applied. // This is used to initialize session thinking state - thinking is only enabled by default // when the user explicitly configured it in their YAML. - if modelCfg.ThinkingBudget != nil && !isThinkingBudgetDisabled(modelCfg.ThinkingBudget) { + if modelCfg.ThinkingBudget != nil && !modelCfg.ThinkingBudget.IsDisabled() { thinkingConfigured = true } @@ -373,14 +359,11 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates modelCfg, exists := cfg.Models[name] if !exists { // Try parsing as inline provider/model format (e.g., "openai/gpt-4o") - providerName, modelName, ok := strings.Cut(name, "/") - if !ok { + parsed, err := latest.ParseModelRef(name) + if err != nil { return nil, fmt.Errorf("fallback model '%s' not found in configuration and is not a valid provider/model format", name) } - modelCfg = latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } + modelCfg = parsed } modelCfg.Name = name diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index 5b9fdd328..4d1882521 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -269,7 +269,7 @@ func TestIsThinkingBudgetDisabled(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := isThinkingBudgetDisabled(tt.budget) + got := tt.budget.IsDisabled() assert.Equal(t, tt.expected, got) }) }