diff --git a/pkg/session/session.go b/pkg/session/session.go index bc2e52732..85b6304d9 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -691,7 +691,9 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message { // trimMessages ensures we don't exceed the maximum number of messages while maintaining // consistency between assistant messages and their tool call results. -// System messages are always preserved and not counted against the limit. +// System messages and user messages are always preserved and not counted against the limit. +// User messages are protected from trimming to prevent the model from losing +// track of what was asked in long agentic loops. func trimMessages(messages []chat.Message, maxItems int) []chat.Message { // Separate system messages from conversation messages var systemMessages []chat.Message @@ -710,15 +712,27 @@ func trimMessages(messages []chat.Message, maxItems int) []chat.Message { return messages } + // Identify user message indices — these are protected from trimming + protected := make(map[int]bool) + for i, msg := range conversationMessages { + if msg.Role == chat.MessageRoleUser { + protected[i] = true + } + } + // Keep track of tool call IDs that need to be removed toolCallsToRemove := make(map[string]bool) // Calculate how many conversation messages we need to remove toRemove := len(conversationMessages) - maxItems - // Start from the beginning (oldest messages) - for i := range toRemove { - // If this is an assistant message with tool calls, mark them for removal + // Mark the oldest non-protected messages for removal + removed := make(map[int]bool) + for i := 0; i < len(conversationMessages) && len(removed) < toRemove; i++ { + if protected[i] { + continue + } + removed[i] = true if conversationMessages[i].Role == chat.MessageRoleAssistant { for _, toolCall := range conversationMessages[i].ToolCalls { toolCallsToRemove[toolCall.ID] = true @@ -732,11 +746,13 @@ func trimMessages(messages []chat.Message, maxItems int) []chat.Message { // Add all system messages first result = append(result, systemMessages...) - // Add the most recent conversation messages - for i := toRemove; i < len(conversationMessages); i++ { - msg := conversationMessages[i] + // Add protected and non-removed conversation messages + for i, msg := range conversationMessages { + if removed[i] { + continue + } - // Skip tool messages that correspond to removed assistant messages + // Skip orphaned tool results whose assistant message was removed if msg.Role == chat.MessageRoleTool && toolCallsToRemove[msg.ToolCallID] { continue } diff --git a/pkg/session/session_history_test.go b/pkg/session/session_history_test.go index cea57b928..a2bec576d 100644 --- a/pkg/session/session_history_test.go +++ b/pkg/session/session_history_test.go @@ -19,16 +19,20 @@ func TestSessionNumHistoryItems(t *testing.T) { expectedConversationMsgs int }{ { - name: "limit to 3 conversation messages", - numHistoryItems: 3, - messageCount: 10, - expectedConversationMsgs: 3, // Limited to 3 despite 20 total messages + name: "limit to 3 conversation messages — user messages protected", + numHistoryItems: 3, + messageCount: 10, + // 10 user (all protected) + 10 assistant. Need to remove 17, but only 10 removable. + // Result: 10 users + 0 assistants = 10 + expectedConversationMsgs: 10, }, { - name: "limit to 5 conversation messages", - numHistoryItems: 5, - messageCount: 8, - expectedConversationMsgs: 5, // Limited to 5 out of 16 total messages + name: "limit to 5 conversation messages — user messages protected", + numHistoryItems: 5, + messageCount: 8, + // 8 user (all protected) + 8 assistant. Need to remove 11, but only 8 removable. + // Result: 8 users + 0 assistants = 8 + expectedConversationMsgs: 8, }, { name: "fewer messages than limit", @@ -71,9 +75,8 @@ func TestSessionNumHistoryItems(t *testing.T) { // System messages should always be present (at least the instruction) assert.Positive(t, systemCount, "Should have system messages") - // Conversation messages should be limited - assert.LessOrEqual(t, conversationCount, tt.expectedConversationMsgs, - "Conversation messages should not exceed the configured limit") + assert.Equal(t, tt.expectedConversationMsgs, conversationCount, + "Conversation messages should match expected count") }) } } @@ -95,22 +98,20 @@ func TestTrimMessagesPreservesSystemMessages(t *testing.T) { // Count message types systemCount := 0 - conversationCount := 0 + userCount := 0 for _, msg := range trimmed { if msg.Role == chat.MessageRoleSystem { systemCount++ - } else { - conversationCount++ + } + if msg.Role == chat.MessageRoleUser { + userCount++ } } // All system messages should be preserved assert.Equal(t, 3, systemCount, "All system messages should be preserved") - assert.Equal(t, 1, conversationCount, "Should have exactly 1 conversation message") - - // The preserved conversation message should be the most recent - assert.Equal(t, "Assistant response 3", trimmed[len(trimmed)-1].Content, - "Should preserve the most recent conversation message") + // All user messages should be preserved even with maxItems=1 + assert.Equal(t, 3, userCount, "All user messages should be preserved") } func TestTrimMessagesConversationLimit(t *testing.T) { @@ -126,16 +127,22 @@ func TestTrimMessagesConversationLimit(t *testing.T) { {Role: chat.MessageRoleAssistant, Content: "Response 4"}, } + // 8 conversation messages: 4 user + 4 assistant + // User messages are always protected, so only assistant messages can be trimmed. testCases := []struct { limit int - expectedTotal int - expectedConversation int expectedSystem int + expectedUser int + expectedConversation int // total non-system }{ - {limit: 2, expectedTotal: 3, expectedConversation: 2, expectedSystem: 1}, - {limit: 4, expectedTotal: 5, expectedConversation: 4, expectedSystem: 1}, - {limit: 8, expectedTotal: 9, expectedConversation: 8, expectedSystem: 1}, - {limit: 100, expectedTotal: 9, expectedConversation: 8, expectedSystem: 1}, + // limit=2: need to remove 6 of 8, but 4 are protected users → only 4 assistants removable → remove 4 + {limit: 2, expectedSystem: 1, expectedUser: 4, expectedConversation: 4}, + // limit=4: need to remove 4 of 8, 4 are protected → remove all 4 assistants + {limit: 4, expectedSystem: 1, expectedUser: 4, expectedConversation: 4}, + // limit=8: no trimming needed (8 <= 8) + {limit: 8, expectedSystem: 1, expectedUser: 4, expectedConversation: 8}, + // limit=100: no trimming needed + {limit: 100, expectedSystem: 1, expectedUser: 4, expectedConversation: 8}, } for _, tc := range testCases { @@ -143,17 +150,22 @@ func TestTrimMessagesConversationLimit(t *testing.T) { trimmed := trimMessages(messages, tc.limit) systemCount := 0 + userCount := 0 conversationCount := 0 for _, msg := range trimmed { - if msg.Role == chat.MessageRoleSystem { + switch msg.Role { + case chat.MessageRoleSystem: systemCount++ - } else { + case chat.MessageRoleUser: + userCount++ + conversationCount++ + default: conversationCount++ } } - assert.Len(t, trimmed, tc.expectedTotal, "Total message count") assert.Equal(t, tc.expectedSystem, systemCount, "System message count") + assert.Equal(t, tc.expectedUser, userCount, "User messages should always be preserved") assert.Equal(t, tc.expectedConversation, conversationCount, "Conversation message count") }) } @@ -190,7 +202,7 @@ func TestTrimMessagesWithToolCallsPreservation(t *testing.T) { }, } - // Limit to 3 conversation messages (should keep the recent tool interaction) + // Limit to 3 conversation messages trimmed := trimMessages(messages, 3) toolCallIDs := make(map[string]bool) @@ -209,12 +221,113 @@ func TestTrimMessagesWithToolCallsPreservation(t *testing.T) { } } - // Should not have the old tool call - hasOldTool := false + // Both user messages should be preserved + userMessages := 0 for _, msg := range trimmed { - if msg.Role == chat.MessageRoleTool && msg.ToolCallID == "old_tool_1" { - hasOldTool = true + if msg.Role == chat.MessageRoleUser { + userMessages++ } } - assert.False(t, hasOldTool, "Should not have old tool results without their calls") + assert.Equal(t, 2, userMessages, "Both user messages should be preserved") +} + +func TestTrimMessagesPreservesUserMessagesInAgenticLoop(t *testing.T) { + // Simulate a single-turn agentic loop: one user message followed by many tool calls + messages := []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "System prompt"}, + {Role: chat.MessageRoleUser, Content: "Analyze MR #123 and build an integration plan"}, + } + + for i := range 30 { + toolID := fmt.Sprintf("tool_%d", i) + messages = append(messages, chat.Message{ + Role: chat.MessageRoleAssistant, + Content: fmt.Sprintf("Calling tool %d", i), + ToolCalls: []tools.ToolCall{ + {ID: toolID, Function: tools.FunctionCall{Name: "shell"}}, + }, + }, chat.Message{ + Role: chat.MessageRoleTool, + Content: fmt.Sprintf("Tool result %d", i), + ToolCallID: toolID, + }) + } + + // 61 conversation messages (1 user + 30 assistant + 30 tool), limit to 30 + trimmed := trimMessages(messages, 30) + + // The user message must survive + var userMessages []string + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleUser { + userMessages = append(userMessages, msg.Content) + } + } + + assert.Len(t, userMessages, 1, "User message must be preserved") + assert.Equal(t, "Analyze MR #123 and build an integration plan", userMessages[0]) + + // Tool call consistency: every tool result must have a matching assistant tool call + toolCallIDs := make(map[string]bool) + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleAssistant { + for _, tc := range msg.ToolCalls { + toolCallIDs[tc.ID] = true + } + } + } + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleTool { + assert.True(t, toolCallIDs[msg.ToolCallID], + "Tool result %s should have a corresponding assistant tool call", msg.ToolCallID) + } + } +} + +func TestTrimMessagesPreservesAllUserMessages(t *testing.T) { + // Multiple user messages interspersed with tool calls + messages := []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "System prompt"}, + {Role: chat.MessageRoleUser, Content: "First request"}, + } + + for i := range 10 { + toolID := fmt.Sprintf("tool_%d", i) + messages = append(messages, chat.Message{ + Role: chat.MessageRoleAssistant, + ToolCalls: []tools.ToolCall{{ID: toolID}}, + }, chat.Message{ + Role: chat.MessageRoleTool, + Content: fmt.Sprintf("result %d", i), + ToolCallID: toolID, + }) + } + + messages = append(messages, chat.Message{Role: chat.MessageRoleUser, Content: "Follow-up request"}) + + for i := 10; i < 20; i++ { + toolID := fmt.Sprintf("tool_%d", i) + messages = append(messages, chat.Message{ + Role: chat.MessageRoleAssistant, + ToolCalls: []tools.ToolCall{{ID: toolID}}, + }, chat.Message{ + Role: chat.MessageRoleTool, + Content: fmt.Sprintf("result %d", i), + ToolCallID: toolID, + }) + } + + // 42 conversation messages (2 user + 20 assistant + 20 tool), limit to 10 + trimmed := trimMessages(messages, 10) + + var userContents []string + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleUser { + userContents = append(userContents, msg.Content) + } + } + + assert.Len(t, userContents, 2, "Both user messages must be preserved") + assert.Equal(t, "First request", userContents[0]) + assert.Equal(t, "Follow-up request", userContents[1]) } diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go index 1ce354488..630f94597 100644 --- a/pkg/session/session_test.go +++ b/pkg/session/session_test.go @@ -58,9 +58,7 @@ func TestTrimMessagesWithToolCalls(t *testing.T) { result := trimMessages(messages, maxItems) - // Should keep last 3 messages, but ensure tool call consistency - assert.Len(t, result, maxItems) - + // Both user messages are protected, so result includes them plus the most recent assistant/tool pair toolCalls := make(map[string]bool) for _, msg := range result { if msg.Role == chat.MessageRoleAssistant {