From c1c58e6b32730e154dce9b5b8a5d2bbc868c19a0 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 19 Mar 2026 16:49:47 -0700 Subject: [PATCH] Trim final FRC to match service storage --- .../packages/core/agent_framework/_agents.py | 71 ++++- .../packages/core/agent_framework/_types.py | 3 + ...est_store_final_function_result_content.py | 297 ++++++++++++++++++ 3 files changed, 369 insertions(+), 2 deletions(-) create mode 100644 python/packages/core/tests/core/test_store_final_function_result_content.py diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index c2c6e874f1..d3be7aeb12 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -172,6 +172,7 @@ class _RunContext(TypedDict): tokenizer: TokenizerProtocol | None client_kwargs: Mapping[str, Any] function_invocation_kwargs: Mapping[str, Any] + store_final_function_result_content: bool # region Agent Protocol @@ -940,6 +941,7 @@ async def _run_non_streaming() -> AgentResponse[Any]: agent_name=ctx["agent_name"], session=ctx["session"], session_context=ctx["session_context"], + store_final_function_result_content=ctx["store_final_function_result_content"], ) response_format = ctx["chat_options"].get("response_format") if not ( @@ -989,8 +991,11 @@ async def _post_hook(response: AgentResponse) -> None: # Run after_run providers (reverse order) session_context = ctx["session_context"] + filtered_messages = self._filter_final_function_result_content( + response.messages, ctx["store_final_function_result_content"] + ) session_context._response = AgentResponse( # type: ignore[assignment] - messages=response.messages, + messages=filtered_messages, response_id=response.response_id, ) await self._run_after_providers(session=ctx["session"], context=session_context) @@ -1109,6 +1114,7 @@ async def _prepare_run_context( ) -> _RunContext: opts = dict(options) if options else {} existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} + store_final_frc: bool = opts.pop("store_final_function_result_content", False) # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) @@ -1234,14 +1240,69 @@ async def _prepare_run_context( "tokenizer": tokenizer or self.tokenizer, "client_kwargs": effective_client_kwargs, "function_invocation_kwargs": additional_function_arguments, + "store_final_function_result_content": store_final_frc, } + @staticmethod + def _filter_final_function_result_content( + response_messages: list[Message], + store_final_function_result_content: bool, + ) -> list[Message]: + """Filter trailing function result messages from response messages. + + Walks backward through the response messages, removing consecutive trailing + messages whose role is ``"tool"`` and whose content is entirely + ``function_result`` type. Messages with mixed content are left unchanged. + The walk stops at the first message that does not match. + + This aligns the behavior of chat history stored via a + :class:`BaseHistoryProvider` with the behavior of agents that store + chat history in the underlying AI service where the final function + result content is never stored. + + Args: + response_messages: The response messages to filter. + store_final_function_result_content: When True, skip filtering + and return messages as-is. + + Returns: + The filtered list of messages, or the original list if nothing was filtered. + """ + if store_final_function_result_content: + return response_messages + + if not response_messages: + return response_messages + + # Walk backward, removing trailing tool-role messages that contain only function_result. + first_kept_index = len(response_messages) + for i in range(len(response_messages) - 1, -1, -1): + message = response_messages[i] + + if message.role != "tool": + break + + all_function_result = len(message.contents) > 0 and all( + content.type == "function_result" for content in message.contents + ) + + if not all_function_result: + break + + first_kept_index = i + + if first_kept_index == len(response_messages): + return response_messages + + return response_messages[:first_kept_index] + async def _finalize_response( self, response: ChatResponse, agent_name: str, session: AgentSession | None, session_context: SessionContext, + store_final_function_result_content: bool = False, ) -> None: """Finalize response by setting author names and running after_run providers. @@ -1250,6 +1311,8 @@ async def _finalize_response( agent_name: The name of the agent to set as author. session: The conversation session. session_context: The invocation context. + store_final_function_result_content: When True, keep trailing + function result content in the stored history. """ # Ensure that the author name is set for each message in the response. for message in response.messages: @@ -1262,9 +1325,13 @@ async def _finalize_response( if session and response.conversation_id and session.service_session_id != response.conversation_id: session.service_session_id = response.conversation_id + filtered_messages = self._filter_final_function_result_content( + response.messages, store_final_function_result_content + ) + # Set the response on the context for after_run providers session_context._response = AgentResponse( # type: ignore[assignment] - messages=response.messages, + messages=filtered_messages, response_id=response.response_id, ) diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index a4e3a57330..345735662a 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -3115,6 +3115,9 @@ class _ChatOptionsBase(TypedDict, total=False): # System/instructions instructions: str + # History storage + store_final_function_result_content: bool + if TYPE_CHECKING: diff --git a/python/packages/core/tests/core/test_store_final_function_result_content.py b/python/packages/core/tests/core/test_store_final_function_result_content.py new file mode 100644 index 0000000000..2d691b5a64 --- /dev/null +++ b/python/packages/core/tests/core/test_store_final_function_result_content.py @@ -0,0 +1,297 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable, Sequence +from typing import Any +from unittest.mock import patch + +from agent_framework import ( + Agent, + AgentResponse, + ChatResponse, + ChatResponseUpdate, + Content, + FunctionInvocationLayer, + InMemoryHistoryProvider, + Message, + ResponseStream, +) +from agent_framework._agents import RawAgent + + +def _make_client( + responses: list[ChatResponse] | None = None, + streaming_responses: list[list[ChatResponseUpdate]] | None = None, +) -> Any: + """Create a mock chat client that supports function invocation.""" + + class _Client(FunctionInvocationLayer): + def __init__(self) -> None: + self.call_count = 0 + self.run_responses: list[ChatResponse] = list(responses or []) + self.streaming_responses: list[list[ChatResponseUpdate]] = list(streaming_responses or []) + + def get_response( + self, + messages: Any, + *, + stream: bool = False, + options: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Any: + options = options or {} + if stream: + return self._stream(options) + return self._non_stream() + + async def _non_stream(self) -> ChatResponse: + self.call_count += 1 + if self.run_responses: + return self.run_responses.pop(0) + return ChatResponse(messages=Message(role="assistant", text="default")) + + def _stream(self, options: dict[str, Any]) -> ResponseStream[ChatResponseUpdate, ChatResponse]: + async def _gen() -> AsyncIterable[ChatResponseUpdate]: + self.call_count += 1 + if self.streaming_responses: + for update in self.streaming_responses.pop(0): + yield update + else: + yield ChatResponseUpdate(contents=[Content.from_text("default")], role="assistant") + + def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: + response_format = options.get("response_format") + output_format_type = response_format if isinstance(response_format, type) else None + return ChatResponse.from_updates(updates, output_format_type=output_format_type) + + return ResponseStream(_gen(), finalizer=_finalize) + + return _Client() + + +def _get_stored_messages(agent: Agent, session: Any) -> list[Message]: + """Retrieve messages stored by the InMemoryHistoryProvider.""" + for provider in agent.context_providers: + if isinstance(provider, InMemoryHistoryProvider): + state = session.state.get(provider.source_id, {}) + return list(state.get("messages", [])) + return [] + + +# -- Unit tests for the static filter method -- + + +def test_filter_returns_original_when_store_is_true() -> None: + messages = [ + Message("assistant", [Content.from_function_call("c1", "get_weather")]), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + ] + result = RawAgent._filter_final_function_result_content(messages, store_final_function_result_content=True) + assert result is messages + + +def test_filter_returns_original_when_empty() -> None: + result = RawAgent._filter_final_function_result_content([], store_final_function_result_content=False) + assert result == [] + + +def test_filter_removes_trailing_function_result_messages() -> None: + messages = [ + Message("assistant", [Content.from_function_call("c1", "get_weather")]), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + ] + result = RawAgent._filter_final_function_result_content(messages, store_final_function_result_content=False) + assert len(result) == 1 + assert result[0].role == "assistant" + + +def test_filter_removes_multiple_trailing_function_result_messages() -> None: + messages = [ + Message( + "assistant", + [ + Content.from_function_call("c1", "get_weather"), + Content.from_function_call("c2", "get_news"), + ], + ), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + Message("tool", [Content.from_function_result("c2", result="Headlines")]), + ] + result = RawAgent._filter_final_function_result_content(messages, store_final_function_result_content=False) + assert len(result) == 1 + assert result[0].role == "assistant" + + +def test_filter_keeps_mixed_content_message() -> None: + messages = [ + Message("tool", [Content.from_text("Some note"), Content.from_function_result("c1", result="Sunny")]), + ] + result = RawAgent._filter_final_function_result_content(messages, store_final_function_result_content=False) + assert len(result) == 1 + assert len(result[0].contents) == 2 + + +def test_filter_no_filtering_when_last_is_not_function_result() -> None: + messages = [ + Message("assistant", [Content.from_text("The weather is sunny.")]), + ] + result = RawAgent._filter_final_function_result_content(messages, store_final_function_result_content=False) + assert len(result) == 1 + assert result[0].contents[0].type == "text" + + +def test_filter_stops_at_non_tool_message() -> None: + messages = [ + Message("assistant", [Content.from_text("Here's the result:")]), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + ] + result = RawAgent._filter_final_function_result_content(messages, store_final_function_result_content=False) + assert len(result) == 1 + assert result[0].role == "assistant" + assert result[0].contents[0].text == "Here's the result:" + + +# -- Integration tests with the Agent class -- + + +async def test_run_filters_final_function_result_when_default() -> None: + """Default behavior (store_final_function_result_content not set) should filter.""" + response_messages = [ + Message("assistant", [Content.from_function_call("c1", "get_weather")]), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + ] + client = _make_client(responses=[ChatResponse(messages=response_messages)]) + + with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 2): + agent = Agent(client=client, name="test-agent") + session = agent.create_session() + await agent.run( + [Message("user", [Content.from_text("What's the weather?")])], + session=session, + ) + + stored = _get_stored_messages(agent, session) + assert len(stored) == 2 # user + assistant (tool message filtered) + assert stored[0].role == "user" + assert stored[1].role == "assistant" + assert stored[1].contents[0].type == "function_call" + + +async def test_run_filters_final_function_result_when_false() -> None: + """Explicit False should filter trailing function result content.""" + response_messages = [ + Message("assistant", [Content.from_function_call("c1", "get_weather")]), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + ] + client = _make_client(responses=[ChatResponse(messages=response_messages)]) + + with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 2): + agent = Agent(client=client, name="test-agent") + session = agent.create_session() + await agent.run( + [Message("user", [Content.from_text("What's the weather?")])], + session=session, + options={"store_final_function_result_content": False}, + ) + + stored = _get_stored_messages(agent, session) + assert len(stored) == 2 # user + assistant (tool message filtered) + assert stored[1].role == "assistant" + assert stored[1].contents[0].type == "function_call" + + +async def test_run_keeps_final_function_result_when_true() -> None: + """When True, trailing function result content should be kept in history.""" + response_messages = [ + Message("assistant", [Content.from_function_call("c1", "get_weather")]), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + ] + client = _make_client(responses=[ChatResponse(messages=response_messages)]) + + with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 2): + agent = Agent(client=client, name="test-agent") + session = agent.create_session() + await agent.run( + [Message("user", [Content.from_text("What's the weather?")])], + session=session, + options={"store_final_function_result_content": True}, + ) + + stored = _get_stored_messages(agent, session) + assert len(stored) == 3 # user + assistant + tool + assert stored[2].role == "tool" + assert stored[2].contents[0].type == "function_result" + + +async def test_run_no_filtering_when_last_is_text() -> None: + """No filtering when the last message is not a function result.""" + response_messages = [ + Message("assistant", [Content.from_text("The weather is sunny.")]), + ] + client = _make_client(responses=[ChatResponse(messages=response_messages)]) + + with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 2): + agent = Agent(client=client, name="test-agent") + session = agent.create_session() + await agent.run( + [Message("user", [Content.from_text("What's the weather?")])], + session=session, + ) + + stored = _get_stored_messages(agent, session) + assert len(stored) == 2 # user + assistant text + assert stored[1].text == "The weather is sunny." + + +async def test_run_returns_unfiltered_response_to_caller() -> None: + """AgentResponse returned to the caller should contain the full unfiltered response.""" + response_messages = [ + Message("assistant", [Content.from_function_call("c1", "get_weather")]), + Message("tool", [Content.from_function_result("c1", result="Sunny")]), + ] + client = _make_client(responses=[ChatResponse(messages=response_messages)]) + + with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 2): + agent = Agent(client=client, name="test-agent") + session = agent.create_session() + result: AgentResponse = await agent.run( + [Message("user", [Content.from_text("What's the weather?")])], + session=session, + options={"store_final_function_result_content": False}, + ) + + # The returned response should have both messages (unfiltered) + assert len(result.messages) == 2 + assert result.messages[-1].contents[0].type == "function_result" + + +async def test_run_streaming_filters_final_function_result_when_default() -> None: + """Streaming path should also filter trailing function result content by default.""" + streaming_updates = [ + ChatResponseUpdate( + contents=[Content.from_function_call("c1", "get_weather")], + role="assistant", + ), + ChatResponseUpdate( + contents=[Content.from_function_result("c1", result="Sunny")], + role="tool", + ), + ] + client = _make_client(streaming_responses=[streaming_updates]) + + with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 2): + agent = Agent(client=client, name="test-agent") + session = agent.create_session() + stream = agent.run( + [Message("user", [Content.from_text("What's the weather?")])], + session=session, + stream=True, + ) + async for _ in stream: + pass + await stream.get_final_response() + + stored = _get_stored_messages(agent, session) + assert len(stored) == 2 # user + assistant (tool filtered) + assert stored[0].role == "user" + assert stored[1].role == "assistant" + assert stored[1].contents[0].type == "function_call"