diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 8cfc016be..7e08ec6aa 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { ".": "0.9.2" -} \ No newline at end of file +} diff --git a/pyproject.toml b/pyproject.toml index e1b8a23da..aca54aeec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "yaspin>=3.1.0", "claude-agent-sdk>=0.1.0", "anthropic>=0.40.0", + "langgraph-checkpoint>=2.0.0", ] requires-python = ">= 3.12,<4" diff --git a/requirements-dev.lock b/requirements-dev.lock index 1078b30de..d4e9e0768 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -16,12 +16,16 @@ aiohttp==3.13.2 # via agentex-sdk # via httpx-aiohttp # via litellm -aiosignal==1.3.2 +aiosignal==1.4.0 # via aiohttp annotated-types==0.7.0 # via pydantic +anthropic==0.79.0 + # via agentex-sdk anyio==4.10.0 # via agentex-sdk + # via anthropic + # via claude-agent-sdk # via httpx # via mcp # via openai @@ -51,6 +55,8 @@ certifi==2023.7.22 # via requests charset-normalizer==3.4.3 # via requests +claude-agent-sdk==0.1.33 + # via agentex-sdk click==8.2.1 # via litellm # via typer @@ -76,9 +82,12 @@ distlib==0.3.7 # via virtualenv distro==1.9.0 # via agentex-sdk + # via anthropic # via openai # via scale-gp # via scale-gp-beta +docstring-parser==0.17.0 + # via anthropic envier==0.6.1 # via ddtrace execnet==2.1.1 @@ -108,7 +117,9 @@ httpcore==1.0.9 # via httpx httpx==0.27.2 # via agentex-sdk + # via anthropic # via httpx-aiohttp + # via langsmith # via litellm # via mcp # via openai @@ -143,9 +154,14 @@ jinja2==3.1.6 # via agentex-sdk # via litellm jiter==0.10.0 + # via anthropic # via openai json-log-formatter==1.1.1 # via agentex-sdk +jsonpatch==1.33 + # via langchain-core +jsonpointer==3.0.0 + # via jsonpatch jsonref==1.1.0 # via agentex-sdk jsonschema==4.25.0 @@ -161,6 +177,12 @@ jupyter-core==5.8.1 # via jupyter-client kubernetes==28.1.0 # via agentex-sdk +langchain-core==1.2.9 + # via langgraph-checkpoint +langgraph-checkpoint==4.0.0 + # via agentex-sdk +langsmith==0.6.9 + # via langchain-core litellm==1.75.5.post1 # via agentex-sdk markdown-it-py==3.0.0 @@ -172,6 +194,7 @@ matplotlib-inline==0.1.7 # via ipython mcp==1.12.4 # via agentex-sdk + # via claude-agent-sdk # via openai-agents mdurl==0.1.2 # via markdown-it-py @@ -199,15 +222,21 @@ openai-agents==0.4.2 # via agentex-sdk opentelemetry-api==1.37.0 # via ddtrace +orjson==3.11.7 + # via langsmith +ormsgpack==1.12.2 + # via langgraph-checkpoint packaging==23.2 # via huggingface-hub # via ipykernel + # via langchain-core + # via langsmith # via nox # via pytest -pathspec==0.12.1 - # via mypy parso==0.8.4 # via jedi +pathspec==0.12.1 + # via mypy pexpect==4.9.0 # via ipython platformdirs==3.11.0 @@ -237,7 +266,10 @@ pyasn1-modules==0.4.2 # via google-auth pydantic==2.11.9 # via agentex-sdk + # via anthropic # via fastapi + # via langchain-core + # via langsmith # via litellm # via mcp # via openai @@ -283,6 +315,7 @@ pyyaml==6.0.2 # via agentex-sdk # via huggingface-hub # via kubernetes + # via langchain-core pyzmq==27.0.1 # via ipykernel # via jupyter-client @@ -299,12 +332,16 @@ requests==2.32.4 # via datadog # via huggingface-hub # via kubernetes + # via langsmith # via openai-agents # via python-on-whales # via requests-oauthlib + # via requests-toolbelt # via tiktoken requests-oauthlib==2.0.0 # via kubernetes +requests-toolbelt==1.0.0 + # via langsmith respx==0.22.0 rich==13.9.4 # via agentex-sdk @@ -329,6 +366,7 @@ six==1.16.0 # via 
python-dateutil sniffio==1.3.1 # via agentex-sdk + # via anthropic # via anyio # via httpx # via openai @@ -343,6 +381,10 @@ starlette==0.46.2 # via mcp temporalio==1.18.2 # via agentex-sdk +tenacity==9.1.4 + # via langchain-core +termcolor==3.3.0 + # via yaspin tiktoken==0.11.0 # via litellm time-machine==2.9.0 @@ -374,9 +416,11 @@ types-urllib3==1.26.25.14 typing-extensions==4.12.2 # via agentex-sdk # via aiosignal + # via anthropic # via anyio # via fastapi # via huggingface-hub + # via langchain-core # via mypy # via nexus-rpc # via openai @@ -392,7 +436,6 @@ typing-extensions==4.12.2 # via temporalio # via typer # via typing-inspection - # via virtualenv typing-inspection==0.4.2 # via pydantic # via pydantic-settings @@ -403,6 +446,9 @@ tzlocal==5.3.1 urllib3==1.26.20 # via kubernetes # via requests +uuid-utils==0.14.0 + # via langchain-core + # via langsmith uvicorn==0.35.0 # via agentex-sdk # via mcp @@ -416,7 +462,13 @@ websocket-client==1.8.0 # via kubernetes wrapt==1.17.3 # via ddtrace +xxhash==3.6.0 + # via langsmith yarl==1.20.0 # via aiohttp +yaspin==3.4.0 + # via agentex-sdk zipp==3.23.0 # via importlib-metadata +zstandard==0.25.0 + # via langsmith diff --git a/requirements.lock b/requirements.lock index 79519671e..24601accb 100644 --- a/requirements.lock +++ b/requirements.lock @@ -16,12 +16,16 @@ aiohttp==3.13.2 # via agentex-sdk # via httpx-aiohttp # via litellm -aiosignal==1.3.2 +aiosignal==1.4.0 # via aiohttp annotated-types==0.7.0 # via pydantic +anthropic==0.79.0 + # via agentex-sdk anyio==4.10.0 # via agentex-sdk + # via anthropic + # via claude-agent-sdk # via httpx # via mcp # via openai @@ -49,6 +53,8 @@ certifi==2023.7.22 # via requests charset-normalizer==3.4.3 # via requests +claude-agent-sdk==0.1.33 + # via agentex-sdk click==8.2.1 # via litellm # via typer @@ -69,9 +75,12 @@ decorator==5.2.1 # via ipython distro==1.8.0 # via agentex-sdk + # via anthropic # via openai # via scale-gp # via scale-gp-beta +docstring-parser==0.17.0 + # via anthropic envier==0.6.1 # via ddtrace executing==2.2.0 @@ -98,7 +107,9 @@ httpcore==1.0.9 # via httpx httpx==0.27.2 # via agentex-sdk + # via anthropic # via httpx-aiohttp + # via langsmith # via litellm # via mcp # via openai @@ -132,9 +143,14 @@ jinja2==3.1.6 # via agentex-sdk # via litellm jiter==0.10.0 + # via anthropic # via openai json-log-formatter==1.1.1 # via agentex-sdk +jsonpatch==1.33 + # via langchain-core +jsonpointer==3.0.0 + # via jsonpatch jsonref==1.1.0 # via agentex-sdk jsonschema==4.25.0 @@ -150,6 +166,12 @@ jupyter-core==5.8.1 # via jupyter-client kubernetes==28.1.0 # via agentex-sdk +langchain-core==1.2.9 + # via langgraph-checkpoint +langgraph-checkpoint==4.0.0 + # via agentex-sdk +langsmith==0.6.9 + # via langchain-core litellm==1.75.5.post1 # via agentex-sdk markdown-it-py==4.0.0 @@ -161,6 +183,7 @@ matplotlib-inline==0.1.7 # via ipython mcp==1.12.4 # via agentex-sdk + # via claude-agent-sdk # via openai-agents mdurl==0.1.2 # via markdown-it-py @@ -182,9 +205,15 @@ openai-agents==0.4.2 # via agentex-sdk opentelemetry-api==1.37.0 # via ddtrace +orjson==3.11.7 + # via langsmith +ormsgpack==1.12.2 + # via langgraph-checkpoint packaging==25.0 # via huggingface-hub # via ipykernel + # via langchain-core + # via langsmith # via pytest parso==0.8.4 # via jedi @@ -200,9 +229,6 @@ prompt-toolkit==3.0.51 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.12.5 - # via agentex-sdk -pydantic-core==2.41.5 protobuf==5.29.5 # via ddtrace # via temporalio @@ -217,6 +243,21 @@ pyasn1==0.6.1 # via rsa 
pyasn1-modules==0.4.2 # via google-auth +pydantic==2.12.5 + # via agentex-sdk + # via anthropic + # via fastapi + # via langchain-core + # via langsmith + # via litellm + # via mcp + # via openai + # via openai-agents + # via pydantic-settings + # via python-on-whales + # via scale-gp + # via scale-gp-beta +pydantic-core==2.41.5 # via pydantic pydantic-settings==2.10.1 # via mcp @@ -247,6 +288,7 @@ pyyaml==6.0.2 # via agentex-sdk # via huggingface-hub # via kubernetes + # via langchain-core pyzmq==27.0.1 # via ipykernel # via jupyter-client @@ -263,12 +305,16 @@ requests==2.32.4 # via datadog # via huggingface-hub # via kubernetes + # via langsmith # via openai-agents # via python-on-whales # via requests-oauthlib + # via requests-toolbelt # via tiktoken requests-oauthlib==2.0.0 # via kubernetes +requests-toolbelt==1.0.0 + # via langsmith rich==13.9.4 # via agentex-sdk # via typer @@ -290,7 +336,8 @@ six==1.17.0 # via python-dateutil sniffio==1.3.0 # via agentex-sdk -typing-extensions==4.15.0 + # via anthropic + # via anyio # via httpx # via openai # via scale-gp @@ -304,6 +351,10 @@ starlette==0.46.2 # via mcp temporalio==1.18.2 # via agentex-sdk +tenacity==9.1.4 + # via langchain-core +termcolor==3.3.0 + # via yaspin tiktoken==0.11.0 # via litellm tokenizers==0.21.4 @@ -331,11 +382,14 @@ types-requests==2.31.0.6 # via openai-agents types-urllib3==1.26.25.14 # via types-requests +typing-extensions==4.15.0 # via agentex-sdk # via aiosignal + # via anthropic # via anyio # via fastapi # via huggingface-hub + # via langchain-core # via nexus-rpc # via openai # via openai-agents @@ -359,6 +413,9 @@ tzlocal==5.3.1 urllib3==1.26.20 # via kubernetes # via requests +uuid-utils==0.14.0 + # via langchain-core + # via langsmith uvicorn==0.35.0 # via agentex-sdk # via mcp @@ -370,7 +427,13 @@ websocket-client==1.8.0 # via kubernetes wrapt==1.17.3 # via ddtrace +xxhash==3.6.0 + # via langsmith yarl==1.20.0 # via aiohttp +yaspin==3.4.0 + # via agentex-sdk zipp==3.23.0 # via importlib-metadata +zstandard==0.25.0 + # via langsmith diff --git a/src/agentex/lib/adk/__init__.py b/src/agentex/lib/adk/__init__.py index cc4e83db4..42bb4fa37 100644 --- a/src/agentex/lib/adk/__init__.py +++ b/src/agentex/lib/adk/__init__.py @@ -5,6 +5,9 @@ from agentex.lib.adk._modules.acp import ACPModule from agentex.lib.adk._modules.agents import AgentsModule from agentex.lib.adk._modules.agent_task_tracker import AgentTaskTrackerModule +from agentex.lib.adk._modules.checkpointer import create_checkpointer +from agentex.lib.adk._modules._langgraph_tracing import create_langgraph_tracing_handler +from agentex.lib.adk._modules._langgraph_streaming import stream_langgraph_events from agentex.lib.adk._modules.events import EventsModule from agentex.lib.adk._modules.messages import MessagesModule from agentex.lib.adk._modules.state import StateModule @@ -27,16 +30,21 @@ __all__ = [ # Core - "acp", + "acp", "agents", - "tasks", - "messages", - "state", - "streaming", - "tracing", + "tasks", + "messages", + "state", + "streaming", + "tracing", "events", "agent_task_tracker", + # Checkpointing / LangGraph + "create_checkpointer", + "create_langgraph_tracing_handler", + "stream_langgraph_events", + # Providers "providers", # Utils diff --git a/src/agentex/lib/adk/_modules/_http_checkpointer.py b/src/agentex/lib/adk/_modules/_http_checkpointer.py new file mode 100644 index 000000000..ce37cc5f2 --- /dev/null +++ b/src/agentex/lib/adk/_modules/_http_checkpointer.py @@ -0,0 +1,380 @@ +"""HTTP-proxy LangGraph checkpointer. 
+ +Proxies all checkpoint operations through the agentex backend API +instead of connecting directly to PostgreSQL. The backend handles DB +operations through its own connection pool. +""" + +from __future__ import annotations + +import base64 +import random +from typing import Any, cast, override +from collections.abc import Iterator, Sequence, AsyncIterator + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + Checkpoint, + ChannelVersions, + CheckpointTuple, + CheckpointMetadata, + BaseCheckpointSaver, + get_checkpoint_id, + get_serializable_checkpoint_metadata, +) +from langgraph.checkpoint.serde.types import TASKS + +from agentex import AsyncAgentex +from agentex.lib.utils.logging import make_logger + +logger = make_logger(__name__) + + +def _bytes_to_b64(data: bytes | None) -> str | None: + if data is None: + return None + return base64.b64encode(data).decode("ascii") + + +def _b64_to_bytes(data: str | None) -> bytes | None: + if data is None: + return None + return base64.b64decode(data) + + +class HttpCheckpointSaver(BaseCheckpointSaver[str]): + """Checkpoint saver that proxies operations through the agentex HTTP API.""" + + def __init__(self, client: AsyncAgentex) -> None: + super().__init__() + self._http = client._client # noqa: SLF001 # raw httpx.AsyncClient for direct HTTP calls + + async def _post(self, path: str, body: dict[str, Any]) -> Any: + """POST JSON to the backend and return parsed response.""" + response = await self._http.post( + f"/checkpoints{path}", + json=body, + ) + response.raise_for_status() + # put-writes and delete-thread return 204 No Content (no JSON body) + if response.status_code == 204: + return None + return response.json() + + # ── get_next_version (same as BasePostgresSaver) ── + + @override + def get_next_version(self, current: str | None, channel: None) -> str: # type: ignore[override] # noqa: ARG002 + if current is None: + current_v = 0 + elif isinstance(current, int): + current_v = current + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + next_h = random.random() # noqa: S311 + return f"{next_v:032}.{next_h:016}" + + # ── async interface ── + + @override + async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: + configurable = config["configurable"] # type: ignore[reportTypedDictNotRequiredAccess] + thread_id = configurable["thread_id"] + checkpoint_ns = configurable.get("checkpoint_ns", "") + checkpoint_id = get_checkpoint_id(config) + + data = await self._post( + "/get-tuple", + { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + }, + ) + + if data is None: + return None + + # Reconstruct channel_values from blobs + inline values + checkpoint = data["checkpoint"] + channel_values: dict[str, Any] = {} + + # Inline primitive values already in the checkpoint + if "channel_values" in checkpoint and checkpoint["channel_values"]: + channel_values.update(checkpoint["channel_values"]) + + # Deserialize blobs + for blob in data.get("blobs", []): + blob_type = blob["type"] + if blob_type == "empty": + continue + blob_bytes = _b64_to_bytes(blob.get("blob")) + channel_values[blob["channel"]] = self.serde.loads_typed((blob_type, blob_bytes)) + + checkpoint["channel_values"] = channel_values + + # Handle pending_sends migration for v < 4 + if checkpoint.get("v", 0) < 4 and data.get("parent_checkpoint_id"): + # The backend already returns all writes; filter for TASKS channel sends + pending_sends_raw = 
[w for w in data.get("pending_writes", []) if w["channel"] == TASKS] + if pending_sends_raw: + sends = [ + self.serde.loads_typed((w["type"], _b64_to_bytes(w["blob"]))) + for w in pending_sends_raw + if w.get("type") + ] + if sends: + enc, blob_data = self.serde.dumps_typed(sends) + channel_values[TASKS] = self.serde.loads_typed((enc, blob_data)) + if checkpoint.get("channel_versions") is None: + checkpoint["channel_versions"] = {} + checkpoint["channel_versions"][TASKS] = ( + max(checkpoint["channel_versions"].values()) + if checkpoint["channel_versions"] + else self.get_next_version(None, None) + ) + + # Reconstruct pending writes + pending_writes: list[tuple[str, str, Any]] = [] + for w in data.get("pending_writes", []): + w_type = w.get("type") + w_bytes = _b64_to_bytes(w.get("blob")) + pending_writes.append( + ( + w["task_id"], + w["channel"], + self.serde.loads_typed((w_type, w_bytes)) if w_type else w_bytes, + ) + ) + + parent_config: RunnableConfig | None = None + if data.get("parent_checkpoint_id"): + parent_config = { + "configurable": { + "thread_id": data["thread_id"], + "checkpoint_ns": data["checkpoint_ns"], + "checkpoint_id": data["parent_checkpoint_id"], + } + } + + return CheckpointTuple( + config={ + "configurable": { + "thread_id": data["thread_id"], + "checkpoint_ns": data["checkpoint_ns"], + "checkpoint_id": data["checkpoint_id"], + } + }, + checkpoint=checkpoint, + metadata=data["metadata"], + parent_config=parent_config, + pending_writes=pending_writes, + ) + + @override + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + configurable = config["configurable"].copy() # type: ignore[reportTypedDictNotRequiredAccess] + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + checkpoint_id = configurable.pop("checkpoint_id", None) + + # Separate inline values from blobs (same logic as AsyncPostgresSaver) + copy = checkpoint.copy() + copy["channel_values"] = copy["channel_values"].copy() + blob_values: dict[str, Any] = {} + for k, v in checkpoint["channel_values"].items(): + if v is None or isinstance(v, (str, int, float, bool)): + pass + else: + blob_values[k] = copy["channel_values"].pop(k) + + # Serialize blob values + blobs: list[dict[str, Any]] = [] + for k, ver in new_versions.items(): + if k in blob_values: + enc, data = self.serde.dumps_typed(blob_values[k]) + blobs.append( + { + "channel": k, + "version": cast(str, ver), + "type": enc, + "blob": _bytes_to_b64(data), + } + ) + else: + blobs.append( + { + "channel": k, + "version": cast(str, ver), + "type": "empty", + "blob": None, + } + ) + + await self._post( + "/put", + { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + "parent_checkpoint_id": checkpoint_id, + "checkpoint": copy, + "metadata": get_serializable_checkpoint_metadata(config, metadata), + "blobs": blobs, + }, + ) + + return { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + @override + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + configurable = config["configurable"] # type: ignore[reportTypedDictNotRequiredAccess] + thread_id = configurable["thread_id"] + checkpoint_ns = configurable["checkpoint_ns"] + checkpoint_id = configurable["checkpoint_id"] + + upsert = all(w[0] 
in WRITES_IDX_MAP for w in writes) + + serialized_writes: list[dict[str, Any]] = [] + for idx, (channel, value) in enumerate(writes): + enc, data = self.serde.dumps_typed(value) + serialized_writes.append( + { + "task_id": task_id, + "idx": WRITES_IDX_MAP.get(channel, idx), + "channel": channel, + "type": enc, + "blob": _bytes_to_b64(data), + "task_path": task_path, + } + ) + + await self._post( + "/put-writes", + { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + "writes": serialized_writes, + "upsert": upsert, + }, + ) + + @override + async def alist( + self, + config: RunnableConfig | None, + *, + filter: dict[str, Any] | None = None, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> AsyncIterator[CheckpointTuple]: + body: dict[str, Any] = {} + if config: + configurable = config["configurable"] # type: ignore[reportTypedDictNotRequiredAccess] + body["thread_id"] = configurable["thread_id"] + checkpoint_ns = configurable.get("checkpoint_ns") + if checkpoint_ns is not None: + body["checkpoint_ns"] = checkpoint_ns + if filter: + body["filter_metadata"] = filter + if before: + body["before_checkpoint_id"] = get_checkpoint_id(before) + if limit is not None: + body["limit"] = limit + + results = await self._post("/list", body) + + for item in results or []: + # For each listed checkpoint, reconstruct a CheckpointTuple + # with inline channel_values only (blobs not included in list) + checkpoint = item["checkpoint"] + parent_config: RunnableConfig | None = None + if item.get("parent_checkpoint_id"): + parent_config = { + "configurable": { + "thread_id": item["thread_id"], + "checkpoint_ns": item["checkpoint_ns"], + "checkpoint_id": item["parent_checkpoint_id"], + } + } + yield CheckpointTuple( + config={ + "configurable": { + "thread_id": item["thread_id"], + "checkpoint_ns": item["checkpoint_ns"], + "checkpoint_id": item["checkpoint_id"], + } + }, + checkpoint=checkpoint, + metadata=item["metadata"], + parent_config=parent_config, + pending_writes=None, + ) + + @override + async def adelete_thread(self, thread_id: str) -> None: + await self._post("/delete-thread", {"thread_id": thread_id}) + + # ── sync stubs (required by BaseCheckpointSaver) ── + # LangGraph always calls the async methods (aget_tuple, aput, etc.). + # Sync methods are only required by the abstract base class. 
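+    # Driving the graph through the sync API (e.g. graph.invoke()) would hit
+    # these stubs and raise NotImplementedError; use ainvoke()/astream().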
+ + @override + def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: + raise NotImplementedError("Use aget_tuple() instead.") + + @override + def list( + self, + config: RunnableConfig | None, + *, + filter: dict[str, Any] | None = None, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> Iterator[CheckpointTuple]: + raise NotImplementedError("Use alist() instead.") + + @override + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + raise NotImplementedError("Use aput() instead.") + + @override + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + raise NotImplementedError("Use aput_writes() instead.") + + @override + def delete_thread(self, thread_id: str) -> None: + raise NotImplementedError("Use adelete_thread() instead.") diff --git a/src/agentex/lib/adk/_modules/_langgraph_streaming.py b/src/agentex/lib/adk/_modules/_langgraph_streaming.py new file mode 100644 index 000000000..f58fe6a51 --- /dev/null +++ b/src/agentex/lib/adk/_modules/_langgraph_streaming.py @@ -0,0 +1,202 @@ +"""Async LangGraph streaming helper for Agentex. + +Converts LangGraph graph.astream() events into Agentex streaming updates +and pushes them to Redis via adk.streaming contexts. For use with async +ACP agents that stream via Redis rather than HTTP yields. +""" + + +async def stream_langgraph_events(stream, task_id: str) -> str: + """Stream LangGraph events to Agentex via Redis. + + Processes the stream from graph.astream() called with + stream_mode=["messages", "updates"] and pushes text, reasoning, + tool request, and tool response messages through Redis streaming + contexts. + + Supports both regular models (chunk.content is a str) and reasoning + models like gpt-5/o1/o3 (chunk.content is a list of typed content blocks + in the Responses API responses/v1 format). + + Args: + stream: Async iterator from graph.astream(..., stream_mode=["messages", "updates"]) + task_id: The Agentex task ID to stream messages to. + + Returns: + The accumulated final text output from the agent. 
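+
+    Example (minimal sketch; assumes ``graph`` was compiled with a
+    checkpointer, hence the ``thread_id`` in the config)::
+
+        stream = graph.astream(
+            {"messages": [("user", "hello")]},
+            config={"configurable": {"thread_id": task_id}},
+            stream_mode=["messages", "updates"],
+        )
+        final_text = await stream_langgraph_events(stream, task_id=task_id)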
+ """ + # Lazy imports so langgraph/langchain aren't required at module load time + from langchain_core.messages import AIMessageChunk, ToolMessage + + from agentex.lib import adk + from agentex.types.reasoning_content import ReasoningContent + from agentex.types.task_message_delta import TextDelta + from agentex.types.reasoning_summary_delta import ReasoningSummaryDelta + from agentex.types.task_message_update import StreamTaskMessageDelta + from agentex.types.text_content import TextContent + from agentex.types.tool_request_content import ToolRequestContent + from agentex.types.tool_response_content import ToolResponseContent + + text_context = None + reasoning_context = None + final_text = "" + + try: + async for event_type, event_data in stream: + if event_type == "messages": + chunk, metadata = event_data + + if not isinstance(chunk, AIMessageChunk) or not chunk.content: + continue + + # ---------------------------------------------------------- + # Case 1: content is a plain string (regular models) + # ---------------------------------------------------------- + if isinstance(chunk.content, str): + if reasoning_context: + await reasoning_context.close() + reasoning_context = None + + if not text_context: + final_text = "" + text_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=TextContent( + author="agent", + content="", + format="markdown", + ), + ).__aenter__() + + final_text += chunk.content + await text_context.stream_update( + StreamTaskMessageDelta( + parent_task_message=text_context.task_message, + delta=TextDelta(type="text", text_delta=chunk.content), + type="delta", + ) + ) + + # ---------------------------------------------------------- + # Case 2: content is a list of typed blocks (reasoning models) + # Responses API (responses/v1) format: + # {"type": "reasoning", "summary": [{"type": "summary_text", "text": "..."}]} + # {"type": "text", "text": "..."} + # ---------------------------------------------------------- + elif isinstance(chunk.content, list): + for block in chunk.content: + if not isinstance(block, dict): + continue + + block_type = block.get("type") + + if block_type == "reasoning": + reasoning_text = "" + for s in block.get("summary", []): + if isinstance(s, dict) and s.get("type") == "summary_text": + reasoning_text += s.get("text", "") + if not reasoning_text: + continue + + if text_context: + await text_context.close() + text_context = None + + if not reasoning_context: + reasoning_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=ReasoningContent( + author="agent", + summary=[], + content=[], + type="reasoning", + style="active", + ), + ).__aenter__() + + await reasoning_context.stream_update( + StreamTaskMessageDelta( + parent_task_message=reasoning_context.task_message, + delta=ReasoningSummaryDelta( + type="reasoning_summary", + summary_index=0, + summary_delta=reasoning_text, + ), + type="delta", + ) + ) + + elif block_type == "text": + text_delta = block.get("text", "") + if not text_delta: + continue + + if reasoning_context: + await reasoning_context.close() + reasoning_context = None + + if not text_context: + final_text = "" + text_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=TextContent( + author="agent", + content="", + format="markdown", + ), + ).__aenter__() + + final_text += text_delta + await text_context.stream_update( + StreamTaskMessageDelta( + 
parent_task_message=text_context.task_message,
+                                delta=TextDelta(type="text", text_delta=text_delta),
+                                type="delta",
+                            )
+                        )
+
+            elif event_type == "updates":
+                for node_name, state_update in event_data.items():
+                    if node_name == "agent":
+                        messages = state_update.get("messages", [])
+                        for msg in messages:
+                            if text_context:
+                                await text_context.close()
+                                text_context = None
+                            if reasoning_context:
+                                await reasoning_context.close()
+                                reasoning_context = None
+
+                            if hasattr(msg, "tool_calls") and msg.tool_calls:
+                                for tc in msg.tool_calls:
+                                    await adk.messages.create(
+                                        task_id=task_id,
+                                        content=ToolRequestContent(
+                                            tool_call_id=tc["id"],
+                                            name=tc["name"],
+                                            arguments=tc["args"],
+                                            author="agent",
+                                        ),
+                                    )
+
+                    elif node_name == "tools":
+                        messages = state_update.get("messages", [])
+                        for msg in messages:
+                            if isinstance(msg, ToolMessage):
+                                await adk.messages.create(
+                                    task_id=task_id,
+                                    content=ToolResponseContent(
+                                        tool_call_id=msg.tool_call_id,
+                                        name=msg.name or "unknown",
+                                        content=msg.content if isinstance(msg.content, str) else str(msg.content),
+                                        author="agent",
+                                    ),
+                                )
+    finally:
+        # Always close open contexts
+        if text_context:
+            await text_context.close()
+        if reasoning_context:
+            await reasoning_context.close()
+
+    return final_text
diff --git a/src/agentex/lib/adk/_modules/_langgraph_tracing.py b/src/agentex/lib/adk/_modules/_langgraph_tracing.py
new file mode 100644
index 000000000..656cdc91e
--- /dev/null
+++ b/src/agentex/lib/adk/_modules/_langgraph_tracing.py
@@ -0,0 +1,238 @@
+"""LangChain callback handler that creates Agentex spans for LLM calls and tool executions."""
+
+from __future__ import annotations
+
+from typing import Any
+from uuid import UUID
+
+from langchain_core.callbacks import AsyncCallbackHandler
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import LLMResult
+
+from agentex.lib.adk._modules.tracing import TracingModule
+from agentex.lib.utils.logging import make_logger
+from agentex.types.span import Span
+
+logger = make_logger(__name__)
+
+
+class AgentexLangGraphTracingHandler(AsyncCallbackHandler):
+    """Async LangChain callback handler that records Agentex tracing spans.
+
+    Creates child spans under a parent span for each LLM call and tool execution.
+    Designed to be passed via ``config={"callbacks": [handler]}`` to LangGraph's
+    ``graph.astream()`` or ``graph.ainvoke()``.
+
+    Span hierarchy produced::
+
+        <parent span> (e.g. "message" turn-level span)
+        ├── llm:<model> (LLM call)
+        ├── tool:<name> (tool execution)
+        └── llm:<model> (LLM call)
+    """
+
+    def __init__(
+        self,
+        trace_id: str,
+        parent_span_id: str | None = None,
+        tracing: TracingModule | None = None,
+    ) -> None:
+        super().__init__()
+        self._trace_id = trace_id
+        self._parent_span_id = parent_span_id
+        # Lazily initialise TracingModule so the httpx client is created
+        # inside the *running* event-loop (not at import/construction time).
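+        # (Constructing it eagerly here would bind the client to whichever
+        # loop is current at construction time, which in sync-ACP servers may
+        # differ from the loop that later serves requests.)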
+ self._tracing_eager = tracing + self._tracing_lazy: TracingModule | None = None + # Map run_id → Span for in-flight spans + self._spans: dict[UUID, Span] = {} + + @property + def _tracing(self) -> TracingModule: + if self._tracing_eager is not None: + return self._tracing_eager + if self._tracing_lazy is None: + self._tracing_lazy = TracingModule() + return self._tracing_lazy + + # ------------------------------------------------------------------ + # LLM lifecycle + # ------------------------------------------------------------------ + + async def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + model_name = (metadata or {}).get("ls_model_name", "") or _extract_model_name(serialized) + span = await self._tracing.start_span( + trace_id=self._trace_id, + name=f"llm:{model_name}" if model_name else "llm", + input=_serialize_messages(messages), + parent_id=self._parent_span_id, + data={"__span_type__": "COMPLETION"}, + ) + if span: + self._spans[run_id] = span + + async def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + span = self._spans.pop(run_id, None) + if span is None: + return + span.output = _serialize_llm_result(response) + await self._tracing.end_span(trace_id=self._trace_id, span=span) + + async def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + span = self._spans.pop(run_id, None) + if span is None: + return + span.output = {"error": str(error)} + await self._tracing.end_span(trace_id=self._trace_id, span=span) + + # ------------------------------------------------------------------ + # Tool lifecycle + # ------------------------------------------------------------------ + + async def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + tool_name = serialized.get("name", "") or serialized.get("id", [""])[-1] + span = await self._tracing.start_span( + trace_id=self._trace_id, + name=f"tool:{tool_name}" if tool_name else "tool", + input={"input": input_str}, + parent_id=self._parent_span_id, + data={"__span_type__": "CUSTOM"}, + ) + if span: + self._spans[run_id] = span + + async def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + span = self._spans.pop(run_id, None) + if span is None: + return + span.output = {"output": output} + await self._tracing.end_span(trace_id=self._trace_id, span=span) + + async def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + span = self._spans.pop(run_id, None) + if span is None: + return + span.output = {"error": str(error)} + await self._tracing.end_span(trace_id=self._trace_id, span=span) + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _extract_model_name(serialized: dict[str, Any]) -> str: + """Best-effort model name extraction from the serialized callback dict.""" + 
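+    # LangChain's serialized payload nests constructor kwargs under "kwargs";
+    # chat models report the model as "model_name" or "model" depending on
+    # the integration, so try both keys.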
kwargs = serialized.get("kwargs", {}) + return kwargs.get("model_name", "") or kwargs.get("model", "") + + +def _serialize_messages(messages: list[list[BaseMessage]]) -> dict[str, Any]: + """Serialize LangChain messages into a JSON-safe dict for the span input.""" + result: list[dict[str, Any]] = [] + for batch in messages: + for msg in batch: + entry: dict[str, Any] = {"type": msg.type, "content": msg.content} + if hasattr(msg, "tool_calls") and msg.tool_calls: + entry["tool_calls"] = msg.tool_calls + result.append(entry) + return {"messages": result} + + +def _serialize_llm_result(response: LLMResult) -> dict[str, Any]: + """Serialize an LLMResult into a JSON-safe dict for the span output.""" + output: dict[str, Any] = {} + if response.generations: + last_gen = response.generations[-1] + if last_gen: + gen = last_gen[-1] + msg = getattr(gen, "message", None) + + # For reasoning models, content is a list of typed blocks. + # Extract text from the blocks instead of relying on gen.text. + if msg and isinstance(msg.content, list): + text_parts: list[str] = [] + for block in msg.content: + if isinstance(block, dict): + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + output["content"] = "".join(text_parts) if text_parts else gen.text + else: + output["content"] = gen.text + + if msg and hasattr(msg, "tool_calls") and msg.tool_calls: + output["tool_calls"] = [ + {"name": tc["name"], "args": tc["args"]} + for tc in msg.tool_calls + ] + return output + + +def create_langgraph_tracing_handler( + trace_id: str, + parent_span_id: str | None = None, +) -> AgentexLangGraphTracingHandler: + """Create a LangChain callback handler that records Agentex tracing spans. + + Pass the returned handler to LangGraph via ``config={"callbacks": [handler]}``. + + Args: + trace_id: The trace ID (typically the task/thread ID). + parent_span_id: Optional parent span ID to nest LLM/tool spans under. + + Returns: + An ``AgentexLangGraphTracingHandler`` instance ready to use as a LangChain callback. + """ + return AgentexLangGraphTracingHandler( + trace_id=trace_id, + parent_span_id=parent_span_id, + ) diff --git a/src/agentex/lib/adk/_modules/checkpointer.py b/src/agentex/lib/adk/_modules/checkpointer.py new file mode 100644 index 000000000..544042941 --- /dev/null +++ b/src/agentex/lib/adk/_modules/checkpointer.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from agentex.lib.adk.utils._modules.client import create_async_agentex_client +from agentex.lib.adk._modules._http_checkpointer import HttpCheckpointSaver + + +async def create_checkpointer() -> HttpCheckpointSaver: + """Create an HTTP-proxy checkpointer for LangGraph. + + Checkpoint operations are proxied through the agentex backend API. + No direct database connection needed — auth is handled via the + agent API key (injected automatically by agentex). + + Usage: + checkpointer = await create_checkpointer() + graph = builder.compile(checkpointer=checkpointer) + """ + client = create_async_agentex_client() + return HttpCheckpointSaver(client=client) diff --git a/src/agentex/lib/adk/_modules/tracing.py b/src/agentex/lib/adk/_modules/tracing.py index 93fd2365e..cb3b4b22b 100644 --- a/src/agentex/lib/adk/_modules/tracing.py +++ b/src/agentex/lib/adk/_modules/tracing.py @@ -39,14 +39,51 @@ def __init__(self, tracing_service: TracingService | None = None): Initialize the tracing interface. Args: - tracing_activities (Optional[TracingActivities]): Optional pre-configured tracing activities. If None, will be auto-initialized. 
+ tracing_service (Optional[TracingService]): Optional pre-configured tracing service. + If None, will be lazily created on first use so the httpx client is + bound to the correct running event loop. """ - if tracing_service is None: - agentex_client = create_async_agentex_client() + self._tracing_service_explicit = tracing_service + self._tracing_service_lazy: TracingService | None = None + self._bound_loop_id: int | None = None + + @property + def _tracing_service(self) -> TracingService: + if self._tracing_service_explicit is not None: + return self._tracing_service_explicit + + import asyncio + + # Determine the current event loop (if any). + try: + loop = asyncio.get_running_loop() + loop_id = id(loop) + except RuntimeError: + loop_id = None + + # Re-create the underlying httpx client when the event loop changes + # (e.g. between HTTP requests in a sync ASGI server) to avoid + # "Event loop is closed" / "bound to a different event loop" errors. + if self._tracing_service_lazy is None or ( + loop_id is not None and loop_id != self._bound_loop_id + ): + import httpx + + # Disable keepalive so each span HTTP call gets a fresh TCP + # connection. Reused connections carry asyncio primitives bound + # to the event loop that created them; in sync-ACP / streaming + # contexts the loop context can shift between calls, causing + # "bound to a different event loop" RuntimeErrors. + agentex_client = create_async_agentex_client( + http_client=httpx.AsyncClient( + limits=httpx.Limits(max_keepalive_connections=0), + ), + ) tracer = AsyncTracer(agentex_client) - self._tracing_service = TracingService(tracer=tracer) - else: - self._tracing_service = tracing_service + self._tracing_service_lazy = TracingService(tracer=tracer) + self._bound_loop_id = loop_id + + return self._tracing_service_lazy @asynccontextmanager async def span( diff --git a/src/agentex/lib/core/services/adk/tracing.py b/src/agentex/lib/core/services/adk/tracing.py index 7e55c7501..210d2f625 100644 --- a/src/agentex/lib/core/services/adk/tracing.py +++ b/src/agentex/lib/core/services/adk/tracing.py @@ -24,14 +24,14 @@ async def start_span( data: list[Any] | dict[str, Any] | BaseModel | None = None, ) -> Span | None: trace = self._tracer.trace(trace_id) - async with trace.span( - parent_id=parent_id, + span = await trace.start_span( name=name, + parent_id=parent_id, input=input or {}, data=data, - ) as span: - heartbeat_if_in_workflow("start span") - return span if span else None + ) + heartbeat_if_in_workflow("start span") + return span async def end_span(self, trace_id: str, span: Span) -> Span: trace = self._tracer.trace(trace_id) diff --git a/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py b/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py index 56e4fa340..54e0d1187 100644 --- a/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py @@ -66,7 +66,18 @@ def shutdown(self) -> None: class AgentexAsyncTracingProcessor(AsyncTracingProcessor): def __init__(self, config: AgentexTracingProcessorConfig): # noqa: ARG002 - self.client = create_async_agentex_client() + import httpx + + # Disable keepalive so each span HTTP call gets a fresh TCP connection. + # Reused connections carry asyncio primitives bound to the event loop + # that created them; in sync-ACP / streaming contexts the loop context + # can shift between calls, causing "bound to a different event loop" + # RuntimeErrors. 
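+        # (The trade-off is a fresh TCP/TLS handshake per span call, accepted
+        # here in exchange for loop-safety.)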
+ self.client = create_async_agentex_client( + http_client=httpx.AsyncClient( + limits=httpx.Limits(max_keepalive_connections=0), + ), + ) @override async def on_span_start(self, span: Span) -> None: diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 8a298121d..e5686f0f5 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -17,6 +17,13 @@ logger = make_logger(__name__) +def _get_span_type(span: Span) -> str: + """Read span_type from span.data['__span_type__'], defaulting to STANDALONE.""" + if isinstance(span.data, dict): + return span.data.get("__span_type__", "STANDALONE") + return "STANDALONE" + + class SGPSyncTracingProcessor(SyncTracingProcessor): def __init__(self, config: SGPTracingProcessorConfig): disabled = config.sgp_api_key == "" or config.sgp_account_id == "" @@ -46,9 +53,10 @@ def _add_source_to_span(self, span: Span) -> None: @override def on_span_start(self, span: Span) -> None: self._add_source_to_span(span) - + sgp_span = create_span( name=span.name, + span_type=_get_span_type(span), span_id=span.id, parent_id=span.parent_id, trace_id=span.trace_id, @@ -86,11 +94,18 @@ class SGPAsyncTracingProcessor(AsyncTracingProcessor): def __init__(self, config: SGPTracingProcessorConfig): self.disabled = config.sgp_api_key == "" or config.sgp_account_id == "" self._spans: dict[str, SGPSpan] = {} + import httpx + + # Disable keepalive so each HTTP call gets a fresh TCP connection, + # avoiding "bound to a different event loop" errors in sync-ACP. self.sgp_async_client = ( AsyncSGPClient( - api_key=config.sgp_api_key, + api_key=config.sgp_api_key, account_id=config.sgp_account_id, base_url=config.sgp_base_url, + http_client=httpx.AsyncClient( + limits=httpx.Limits(max_keepalive_connections=0), + ), ) if not self.disabled else None @@ -114,6 +129,7 @@ async def on_span_start(self, span: Span) -> None: self._add_source_to_span(span) sgp_span = create_span( name=span.name, + span_type=_get_span_type(span), span_id=span.id, parent_id=span.parent_id, trace_id=span.trace_id, diff --git a/src/agentex/lib/types/converters.py b/src/agentex/lib/types/converters.py index 1e3676b55..276e3b0bb 100644 --- a/src/agentex/lib/types/converters.py +++ b/src/agentex/lib/types/converters.py @@ -62,3 +62,222 @@ def convert_task_messages_to_oai_agents_inputs( ) return converted_messages + + +async def convert_langgraph_to_agentex_events(stream): + """Convert LangGraph streaming events to Agentex TaskMessageUpdate events. + + Expects the stream from graph.astream() called with + stream_mode=["messages", "updates"]. This produces two event types: + + ("messages", (message_chunk, metadata)) — token-by-token LLM output + ("updates", {node_name: state_update}) — complete node outputs + + Text tokens are streamed as Start/Delta/Done sequences. + Reasoning tokens are streamed as Start/Delta/Done sequences with ReasoningContentDelta. + Tool calls and tool results are emitted as Full messages. + + Supports both regular models (chunk.content is a str) and reasoning models + like gpt-5/o1/o3 (chunk.content is a list of typed content blocks). 
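+
+    Unlike ``stream_langgraph_events``, nothing is pushed to Redis; events are
+    yielded back to the caller, which suits sync ACP agents that stream
+    updates over HTTP yields.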
+ + Args: + stream: Async iterator from graph.astream(..., stream_mode=["messages", "updates"]) + + Yields: + TaskMessageUpdate events (Start, Delta, Done, Full) + """ + # Lazy imports so langgraph/langchain aren't required at module load time + from langchain_core.messages import AIMessageChunk, ToolMessage + + from agentex.types.reasoning_content_delta import ReasoningContentDelta + from agentex.types.reasoning_summary_delta import ReasoningSummaryDelta + from agentex.types.task_message_delta import TextDelta + from agentex.types.task_message_update import ( + StreamTaskMessageStart, + StreamTaskMessageDelta, + StreamTaskMessageDone, + StreamTaskMessageFull, + ) + + message_index = 0 + text_streaming = False + reasoning_streaming = False + reasoning_content_index = 0 + + async for event_type, event_data in stream: + if event_type == "messages": + chunk, metadata = event_data + + if not isinstance(chunk, AIMessageChunk) or not chunk.content: + continue + + # ---------------------------------------------------------- + # Case 1: content is a plain string (regular models) + # ---------------------------------------------------------- + if isinstance(chunk.content, str): + # Close reasoning stream if we're transitioning to text + if reasoning_streaming: + yield StreamTaskMessageDone(type="done", index=message_index) + reasoning_streaming = False + message_index += 1 + + if not text_streaming: + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=TextContent(type="text", author="agent", content=""), + ) + text_streaming = True + + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=TextDelta(type="text", text_delta=chunk.content), + ) + + # ---------------------------------------------------------- + # Case 2: content is a list of typed blocks (reasoning models) + # Responses API (responses/v1) format: + # {"type": "reasoning", "summary": [{"type": "summary_text", "text": "..."}]} + # {"type": "text", "text": "..."} + # ---------------------------------------------------------- + elif isinstance(chunk.content, list): + for block in chunk.content: + if not isinstance(block, dict): + continue + + block_type = block.get("type") + + if block_type == "reasoning": + # Responses API: reasoning text is inside summary list + reasoning_text = "" + summaries = block.get("summary", []) + for s in summaries: + if isinstance(s, dict) and s.get("type") == "summary_text": + reasoning_text += s.get("text", "") + if not reasoning_text: + continue + + # Close text stream if transitioning to reasoning + if text_streaming: + yield StreamTaskMessageDone(type="done", index=message_index) + text_streaming = False + message_index += 1 + + if not reasoning_streaming: + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=TextContent(type="text", author="agent", content=""), + ) + reasoning_streaming = True + reasoning_content_index = 0 + + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ReasoningContentDelta( + type="reasoning_content", + content_index=reasoning_content_index, + content_delta=reasoning_text, + ), + ) + + elif block_type == "text": + text_delta = block.get("text", "") + if not text_delta: + continue + + # Close reasoning stream if transitioning to text + if reasoning_streaming: + yield StreamTaskMessageDone(type="done", index=message_index) + reasoning_streaming = False + reasoning_content_index += 1 + message_index += 1 + + if not text_streaming: + yield StreamTaskMessageStart( + type="start", + 
index=message_index, + content=TextContent(type="text", author="agent", content=""), + ) + text_streaming = True + + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=TextDelta(type="text", text_delta=text_delta), + ) + + # ---------------------------------------------------------- + # Reasoning summaries via additional_kwargs (OpenAI v0.3 format) + # ---------------------------------------------------------- + additional_kwargs = getattr(chunk, "additional_kwargs", {}) + reasoning_kw = additional_kwargs.get("reasoning") + if isinstance(reasoning_kw, dict): + summaries = reasoning_kw.get("summary", []) + for si, summary_item in enumerate(summaries): + if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text": + summary_text = summary_item.get("text", "") + if summary_text: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ReasoningSummaryDelta( + type="reasoning_summary", + summary_index=si, + summary_delta=summary_text, + ), + ) + + elif event_type == "updates": + for node_name, state_update in event_data.items(): + if node_name == "agent": + messages = state_update.get("messages", []) + for msg in messages: + # Close any open streams + if text_streaming: + yield StreamTaskMessageDone(type="done", index=message_index) + text_streaming = False + message_index += 1 + if reasoning_streaming: + yield StreamTaskMessageDone(type="done", index=message_index) + reasoning_streaming = False + message_index += 1 + + # Emit tool requests if the agent decided to call tools + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + yield StreamTaskMessageFull( + type="full", + index=message_index, + content=ToolRequestContent( + tool_call_id=tc["id"], + name=tc["name"], + arguments=tc["args"], + author="agent", + ), + ) + message_index += 1 + + elif node_name == "tools": + messages = state_update.get("messages", []) + for msg in messages: + if isinstance(msg, ToolMessage): + yield StreamTaskMessageFull( + type="full", + index=message_index, + content=ToolResponseContent( + tool_call_id=msg.tool_call_id, + name=msg.name or "unknown", + content=msg.content if isinstance(msg.content, str) else str(msg.content), + author="agent", + ), + ) + message_index += 1 + + # Close any remaining open streams + if text_streaming: + yield StreamTaskMessageDone(type="done", index=message_index) + if reasoning_streaming: + yield StreamTaskMessageDone(type="done", index=message_index)