diff --git a/src/uipath_langchain/agent/tools/durable_interrupt/__init__.py b/src/uipath_langchain/_utils/durable_interrupt/__init__.py similarity index 75% rename from src/uipath_langchain/agent/tools/durable_interrupt/__init__.py rename to src/uipath_langchain/_utils/durable_interrupt/__init__.py index c814b0be3..bd36440fb 100644 --- a/src/uipath_langchain/agent/tools/durable_interrupt/__init__.py +++ b/src/uipath_langchain/_utils/durable_interrupt/__init__.py @@ -1,6 +1,9 @@ """Durable interrupt package for side-effect-safe interrupt/resume in LangGraph.""" -from .decorator import _durable_state, durable_interrupt +from .decorator import ( + _durable_state, + durable_interrupt, +) from .skip_interrupt import SkipInterruptValue __all__ = [ diff --git a/src/uipath_langchain/agent/tools/durable_interrupt/decorator.py b/src/uipath_langchain/_utils/durable_interrupt/decorator.py similarity index 100% rename from src/uipath_langchain/agent/tools/durable_interrupt/decorator.py rename to src/uipath_langchain/_utils/durable_interrupt/decorator.py diff --git a/src/uipath_langchain/agent/tools/durable_interrupt/skip_interrupt.py b/src/uipath_langchain/_utils/durable_interrupt/skip_interrupt.py similarity index 100% rename from src/uipath_langchain/agent/tools/durable_interrupt/skip_interrupt.py rename to src/uipath_langchain/_utils/durable_interrupt/skip_interrupt.py diff --git a/src/uipath_langchain/agent/tools/context_tool.py b/src/uipath_langchain/agent/tools/context_tool.py index 51837c45c..95fd4961a 100644 --- a/src/uipath_langchain/agent/tools/context_tool.py +++ b/src/uipath_langchain/agent/tools/context_tool.py @@ -26,6 +26,7 @@ from uipath.runtime.errors import UiPathErrorCategory from uipath_langchain._utils import get_execution_folder_path +from uipath_langchain._utils.durable_interrupt import durable_interrupt from uipath_langchain.agent.exceptions import AgentStartupError, AgentStartupErrorCode from uipath_langchain.agent.react.jsonschema_pydantic_converter import ( create_model as create_model_from_schema, @@ -40,7 +41,6 @@ ) from uipath_langchain.retrievers import ContextGroundingRetriever -from .durable_interrupt import durable_interrupt from .structured_tool_with_argument_properties import ( StructuredToolWithArgumentProperties, ) diff --git a/src/uipath_langchain/agent/tools/escalation_tool.py b/src/uipath_langchain/agent/tools/escalation_tool.py index ead8fdef3..1410c2e68 100644 --- a/src/uipath_langchain/agent/tools/escalation_tool.py +++ b/src/uipath_langchain/agent/tools/escalation_tool.py @@ -21,6 +21,7 @@ from uipath.runtime.errors import UiPathErrorCategory from uipath_langchain._utils import get_execution_folder_path +from uipath_langchain._utils.durable_interrupt import durable_interrupt from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model from uipath_langchain.agent.tools.static_args import ( handle_static_args, @@ -31,7 +32,6 @@ from ..exceptions import AgentRuntimeError, AgentRuntimeErrorCode from ..react.types import AgentGraphState -from .durable_interrupt import durable_interrupt from .tool_node import ToolWrapperReturnType from .utils import ( resolve_task_title, diff --git a/src/uipath_langchain/agent/tools/internal_tools/batch_transform_tool.py b/src/uipath_langchain/agent/tools/internal_tools/batch_transform_tool.py index 6dcb55add..0b61b30ba 100644 --- a/src/uipath_langchain/agent/tools/internal_tools/batch_transform_tool.py +++ b/src/uipath_langchain/agent/tools/internal_tools/batch_transform_tool.py @@ -26,13 +26,13 @@ ) from uipath.runtime.errors import UiPathErrorCategory -from uipath_langchain.agent.exceptions import AgentStartupError, AgentStartupErrorCode -from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model -from uipath_langchain.agent.react.types import AgentGraphState -from uipath_langchain.agent.tools.durable_interrupt import ( +from uipath_langchain._utils.durable_interrupt import ( SkipInterruptValue, durable_interrupt, ) +from uipath_langchain.agent.exceptions import AgentStartupError, AgentStartupErrorCode +from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model +from uipath_langchain.agent.react.types import AgentGraphState from uipath_langchain.agent.tools.internal_tools.schema_utils import ( BATCH_TRANSFORM_OUTPUT_SCHEMA, add_query_field_to_schema, diff --git a/src/uipath_langchain/agent/tools/internal_tools/deeprag_tool.py b/src/uipath_langchain/agent/tools/internal_tools/deeprag_tool.py index 9effde108..6537a38ed 100644 --- a/src/uipath_langchain/agent/tools/internal_tools/deeprag_tool.py +++ b/src/uipath_langchain/agent/tools/internal_tools/deeprag_tool.py @@ -22,13 +22,13 @@ ) from uipath.runtime.errors import UiPathErrorCategory -from uipath_langchain.agent.exceptions import AgentStartupError, AgentStartupErrorCode -from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model -from uipath_langchain.agent.react.types import AgentGraphState -from uipath_langchain.agent.tools.durable_interrupt import ( +from uipath_langchain._utils.durable_interrupt import ( SkipInterruptValue, durable_interrupt, ) +from uipath_langchain.agent.exceptions import AgentStartupError, AgentStartupErrorCode +from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model +from uipath_langchain.agent.react.types import AgentGraphState from uipath_langchain.agent.tools.internal_tools.schema_utils import ( add_query_field_to_schema, ) diff --git a/src/uipath_langchain/agent/tools/ixp_escalation_tool.py b/src/uipath_langchain/agent/tools/ixp_escalation_tool.py index 68be1339e..98d593d7d 100644 --- a/src/uipath_langchain/agent/tools/ixp_escalation_tool.py +++ b/src/uipath_langchain/agent/tools/ixp_escalation_tool.py @@ -18,6 +18,7 @@ ) from uipath.runtime.errors import UiPathErrorCategory +from uipath_langchain._utils.durable_interrupt import durable_interrupt from uipath_langchain.agent.react.types import AgentGraphState from uipath_langchain.agent.tools.tool_node import ( ToolWrapperMixin, @@ -25,7 +26,6 @@ ) from ..exceptions import AgentRuntimeError, AgentRuntimeErrorCode -from .durable_interrupt import durable_interrupt from .structured_tool_with_output_type import StructuredToolWithOutputType from .utils import ( resolve_task_title, diff --git a/src/uipath_langchain/agent/tools/process_tool.py b/src/uipath_langchain/agent/tools/process_tool.py index 7bf9e647a..233e1b060 100644 --- a/src/uipath_langchain/agent/tools/process_tool.py +++ b/src/uipath_langchain/agent/tools/process_tool.py @@ -13,6 +13,7 @@ from uipath.platform.orchestrator import JobState from uipath_langchain._utils import get_execution_folder_path +from uipath_langchain._utils.durable_interrupt import durable_interrupt from uipath_langchain.agent.react.job_attachments import get_job_attachments from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model from uipath_langchain.agent.react.types import AgentGraphState @@ -24,7 +25,6 @@ ToolWrapperReturnType, ) -from .durable_interrupt import durable_interrupt from .utils import sanitize_tool_name diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 8a87fec87..8f7bc5ca3 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -16,6 +16,8 @@ LowCodeAgentDefinition, ) +from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION + from .context_tool import create_context_tool from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool @@ -54,6 +56,15 @@ async def create_tools_from_resources( else: tools.append(tool) + if agent.is_conversational: + props = getattr(resource, "properties", None) + if props and getattr( + props, REQUIRE_CONVERSATIONAL_CONFIRMATION, False + ): + if tool.metadata is None: + tool.metadata = {} + tool.metadata[REQUIRE_CONVERSATIONAL_CONFIRMATION] = True + return tools diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index eb187ebab..0bcfe5a2c 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -22,6 +22,7 @@ extract_current_tool_call_index, find_latest_ai_message, ) +from uipath_langchain.chat.hitl import request_conversational_tool_confirmation # the type safety can be improved with generics ToolWrapperReturnType = dict[str, Any] | Command[Any] | None @@ -80,6 +81,15 @@ def _func(self, state: AgentGraphState) -> OutputType: if call is None: return None + # prompt user for approval if tool requires confirmation + conversational_confirmation = request_conversational_tool_confirmation( + call, self.tool + ) + if conversational_confirmation: + if conversational_confirmation.cancelled: + # tool confirmation rejected + return self._process_result(call, conversational_confirmation.cancelled) + try: if self.wrapper: inputs = self._prepare_wrapper_inputs( @@ -88,7 +98,11 @@ def _func(self, state: AgentGraphState) -> OutputType: result = self.wrapper(*inputs) else: result = self.tool.invoke(call) - return self._process_result(call, result) + output = self._process_result(call, result) + if conversational_confirmation: + # HITL approved - apply confirmation metadata to tool result message + conversational_confirmation.annotate_result(output) + return output except GraphBubbleUp: # LangGraph uses exceptions for interrupt control flow — re-raise so # handle_tool_errors doesn't swallow expected interrupts as errors. @@ -104,15 +118,29 @@ async def _afunc(self, state: AgentGraphState) -> OutputType: if call is None: return None + # prompt user for approval if tool requires confirmation + conversational_confirmation = request_conversational_tool_confirmation( + call, self.tool + ) + if conversational_confirmation: + if conversational_confirmation.cancelled: + # tool confirmation rejected + return self._process_result(call, conversational_confirmation.cancelled) + try: if self.awrapper: inputs = self._prepare_wrapper_inputs( self.awrapper, self.tool, call, state ) + result = await self.awrapper(*inputs) else: result = await self.tool.ainvoke(call) - return self._process_result(call, result) + output = self._process_result(call, result) + if conversational_confirmation: + # HITL approved - apply confirmation metadata to tool result message + conversational_confirmation.annotate_result(output) + return output except GraphBubbleUp: # LangGraph uses exceptions for interrupt control flow — re-raise so # handle_tool_errors doesn't swallow expected interrupts as errors. diff --git a/src/uipath_langchain/chat/hitl.py b/src/uipath_langchain/chat/hitl.py index 625fc9a63..228d1b365 100644 --- a/src/uipath_langchain/chat/hitl.py +++ b/src/uipath_langchain/chat/hitl.py @@ -1,16 +1,66 @@ import functools import inspect +import json from inspect import Parameter -from typing import Annotated, Any, Callable +from typing import Annotated, Any, Callable, NamedTuple +from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.tools import BaseTool, InjectedToolCallId from langchain_core.tools import tool as langchain_tool -from langgraph.types import interrupt from uipath.core.chat import ( UiPathConversationToolCallConfirmationValue, ) -_CANCELLED_MESSAGE = "Cancelled by user" +from uipath_langchain._utils.durable_interrupt import durable_interrupt + +CANCELLED_MESSAGE = "Cancelled by user" + +CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args" +REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation" + + +class ConfirmationResult(NamedTuple): + """Result of a tool confirmation check.""" + + cancelled: ToolMessage | None # ToolMessage if cancelled, None if approved + args_modified: bool + approved_args: dict[str, Any] | None = None + + def annotate_result(self, output: dict[str, Any] | Any) -> None: + """Apply confirmation metadata to a tool result message.""" + msg = None + if isinstance(output, dict): + messages = output.get("messages") + if messages: + msg = messages[0] + else: + # Tools with @durable_interrupt return a Command whose messages + # are nested under output.update["messages"]. + update = getattr(output, "update", None) + if isinstance(update, dict): + messages = update.get("messages") + if messages: + msg = messages[0] + if msg is None: + return + if self.approved_args is not None: + msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = ( + self.approved_args + ) + if self.args_modified: + try: + result_value = json.loads(msg.content) + except (json.JSONDecodeError, TypeError): + result_value = msg.content + msg.content = json.dumps( + { + "meta": { + "args_modified_by_user": True, + "executed_args": self.approved_args, + }, + "result": result_value, + } + ) def _patch_span_input(approved_args: dict[str, Any]) -> None: @@ -53,7 +103,7 @@ def _patch_span_input(approved_args: dict[str, Any]) -> None: pass -def _request_approval( +def request_approval( tool_args: dict[str, Any], tool: BaseTool, ) -> dict[str, Any] | None: @@ -70,14 +120,16 @@ def _request_approval( if tool_call_schema is not None: input_schema = tool_call_schema.model_json_schema() - response = interrupt( - UiPathConversationToolCallConfirmationValue( + @durable_interrupt + def ask_confirmation(): + return UiPathConversationToolCallConfirmationValue( tool_call_id=tool_call_id, tool_name=tool.name, input_schema=input_schema, input_value=tool_args, ) - ) + + response = ask_confirmation() # The resume payload from CAS has shape: # {"type": "uipath_cas_tool_call_confirmation", @@ -89,9 +141,46 @@ def _request_approval( if not confirmation.get("approved", True): return None - return confirmation.get("input") or tool_args + return ( + confirmation.get("input") + if confirmation.get("input") is not None + else tool_args + ) + +# for conversational low code agents +def request_conversational_tool_confirmation( + call: ToolCall, tool: BaseTool +) -> ConfirmationResult | None: + """Check whether a tool requires user confirmation and request approval""" + if not (tool.metadata and tool.metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION)): + return None + original_args = call["args"] + approved_args = request_approval( + {**original_args, "tool_call_id": call["id"]}, tool + ) + if approved_args is None: + cancelled_msg = ToolMessage( + content=json.dumps({"meta": CANCELLED_MESSAGE}), + name=call["name"], + tool_call_id=call["id"], + ) + cancelled_msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = ( + original_args + ) + return ConfirmationResult(cancelled=cancelled_msg, args_modified=False) + + # Mutate call args so the tool executes with the approved values + call["args"] = approved_args + return ConfirmationResult( + cancelled=None, + args_modified=approved_args != original_args, + approved_args=approved_args, + ) + + +# for conversational coded agents def requires_approval( func: Callable[..., Any] | None = None, *, @@ -107,9 +196,10 @@ def decorator(fn: Callable[..., Any]) -> BaseTool: # wrap the tool/function @functools.wraps(fn) def wrapper(**tool_args: Any) -> Any: - approved_args = _request_approval(tool_args, _created_tool[0]) + approved_args = request_approval(tool_args, _created_tool[0]) if approved_args is None: - return _CANCELLED_MESSAGE + return json.dumps({"meta": CANCELLED_MESSAGE}) + _patch_span_input(approved_args) return fn(**approved_args) diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 53712e912..43aba1626 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -58,6 +58,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None """Initialize the mapper with empty state.""" self.runtime_id = runtime_id self.storage = storage + self.tool_names_requiring_confirmation: set[str] = set() self.current_message: AIMessageChunk self.seen_message_ids: set[str] = set() self._storage_lock = asyncio.Lock() @@ -389,11 +390,17 @@ async def map_current_message_to_start_tool_call_events(self): tool_call_id_to_message_id_map[tool_call_id] = ( self.current_message.id ) - events.append( - self.map_tool_call_to_tool_call_start_event( - self.current_message.id, tool_call + + # if tool requires confirmation, we skip start tool call + if ( + tool_call["name"] + not in self.tool_names_requiring_confirmation + ): + events.append( + self.map_tool_call_to_tool_call_start_event( + self.current_message.id, tool_call + ) ) - ) if self.storage is not None: await self.storage.set_value( @@ -665,7 +672,7 @@ def _map_langchain_ai_message_to_uipath_message_data( role="assistant", content_parts=content_parts, tool_calls=uipath_tool_calls, - interrupts=[], # TODO: Interrupts + interrupts=[], ) diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index 228a5cdb9..feb327018 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -29,6 +29,7 @@ ) from uipath.runtime.schema import UiPathRuntimeSchema +from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError from uipath_langchain.runtime.messages import UiPathChatMessagesMapper from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema @@ -64,6 +65,9 @@ def __init__( self.entrypoint: str | None = entrypoint self.callbacks: list[BaseCallbackHandler] = callbacks or [] self.chat = UiPathChatMessagesMapper(self.runtime_id, storage) + self.chat.tool_names_requiring_confirmation = ( + self._get_tool_names_requiring_confirmation() + ) self._middleware_node_names: set[str] = self._detect_middleware_nodes() async def execute( @@ -486,6 +490,18 @@ def _detect_middleware_nodes(self) -> set[str]: return middleware_nodes + def _get_tool_names_requiring_confirmation(self) -> set[str]: + names: set[str] = set() + for node_name, node_spec in self.graph.nodes.items(): + # langgraph's processing node.bound -> runnable.tool -> baseTool (if tool node) + tool = getattr(getattr(node_spec, "bound", None), "tool", None) + if tool is None: + continue + metadata = getattr(tool, "metadata", None) or {} + if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION): + names.add(getattr(tool, "name", node_name)) + return names + def _is_middleware_node(self, node_name: str) -> bool: """Check if a node name represents a middleware node.""" return node_name in self._middleware_node_names diff --git a/tests/agent/tools/internal_tools/test_batch_transform_tool.py b/tests/agent/tools/internal_tools/test_batch_transform_tool.py index 5831913f0..6b3567868 100644 --- a/tests/agent/tools/internal_tools/test_batch_transform_tool.py +++ b/tests/agent/tools/internal_tools/test_batch_transform_tool.py @@ -150,7 +150,7 @@ def resource_config_dynamic(self, batch_transform_settings_dynamic_query): "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPathConfig" ) @patch("uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.mockable", lambda **kwargs: lambda f: f, @@ -242,7 +242,7 @@ async def test_create_batch_transform_tool_static_query_index_ready( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPathConfig" ) @patch("uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.mockable", lambda **kwargs: lambda f: f, @@ -323,7 +323,7 @@ async def test_create_batch_transform_tool_static_query_wait_for_ingestion( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPathConfig" ) @patch("uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.mockable", lambda **kwargs: lambda f: f, @@ -395,7 +395,7 @@ async def test_create_batch_transform_tool_dynamic_query( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPathConfig" ) @patch("uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.mockable", lambda **kwargs: lambda f: f, @@ -475,7 +475,7 @@ async def test_create_batch_transform_tool_default_destination_path( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPathConfig" ) @patch("uipath_langchain.agent.tools.internal_tools.batch_transform_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.batch_transform_tool.mockable", lambda **kwargs: lambda f: f, diff --git a/tests/agent/tools/internal_tools/test_deeprag_tool.py b/tests/agent/tools/internal_tools/test_deeprag_tool.py index 3934bb73e..1d2a9d2ef 100644 --- a/tests/agent/tools/internal_tools/test_deeprag_tool.py +++ b/tests/agent/tools/internal_tools/test_deeprag_tool.py @@ -122,7 +122,7 @@ def resource_config_dynamic(self, deeprag_settings_dynamic_query): "uipath_langchain.agent.wrappers.job_attachment_wrapper.get_job_attachment_wrapper" ) @patch("uipath_langchain.agent.tools.internal_tools.deeprag_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.deeprag_tool.mockable", lambda **kwargs: lambda f: f, @@ -192,7 +192,7 @@ async def test_create_deeprag_tool_static_query_index_ready( "uipath_langchain.agent.wrappers.job_attachment_wrapper.get_job_attachment_wrapper" ) @patch("uipath_langchain.agent.tools.internal_tools.deeprag_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.deeprag_tool.mockable", lambda **kwargs: lambda f: f, @@ -257,7 +257,7 @@ async def test_create_deeprag_tool_static_query_wait_for_ingestion( "uipath_langchain.agent.wrappers.job_attachment_wrapper.get_job_attachment_wrapper" ) @patch("uipath_langchain.agent.tools.internal_tools.deeprag_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch( "uipath_langchain.agent.tools.internal_tools.deeprag_tool.mockable", lambda **kwargs: lambda f: f, diff --git a/tests/agent/tools/test_context_tool.py b/tests/agent/tools/test_context_tool.py index eaaabf00d..da867a9d6 100644 --- a/tests/agent/tools/test_context_tool.py +++ b/tests/agent/tools/test_context_tool.py @@ -218,7 +218,7 @@ async def test_tool_with_different_citation_modes(self, base_resource_config): tool = handle_deep_rag("test_tool", resource) with patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt: mock_interrupt.return_value = {"mocked": "response"} assert tool.coroutine is not None @@ -240,7 +240,7 @@ async def test_unique_task_names_on_multiple_invocations( task_names = [] with patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt: mock_interrupt.return_value = {"mocked": "response"} @@ -301,7 +301,7 @@ async def test_dynamic_query_uses_provided_query(self, base_resource_config): tool = handle_deep_rag("test_tool", resource) with patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt: mock_interrupt.return_value = {"mocked": "response"} assert tool.coroutine is not None @@ -322,7 +322,7 @@ async def test_deep_rag_uses_execution_folder_path(self, base_resource_config): tool = handle_deep_rag("test_tool", resource) with patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt: mock_interrupt.return_value = {"mocked": "response"} assert tool.coroutine is not None @@ -693,7 +693,7 @@ async def test_static_query_batch_transform_uses_predefined_query( mock_uipath.jobs.create_attachment_async = AsyncMock(return_value="att-id-1") with ( patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt, patch( "uipath_langchain.agent.tools.context_tool.UiPath", @@ -741,7 +741,7 @@ async def test_dynamic_query_batch_transform_uses_provided_query(self): mock_uipath.jobs.create_attachment_async = AsyncMock(return_value="att-id-2") with ( patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt, patch( "uipath_langchain.agent.tools.context_tool.UiPath", @@ -769,7 +769,7 @@ async def test_static_query_batch_transform_uses_default_destination_path( mock_uipath.jobs.create_attachment_async = AsyncMock(return_value="att-id-3") with ( patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt, patch( "uipath_langchain.agent.tools.context_tool.UiPath", @@ -818,7 +818,7 @@ async def test_dynamic_query_batch_transform_uses_default_destination_path(self) mock_uipath.jobs.create_attachment_async = AsyncMock(return_value="att-id-4") with ( patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt, patch( "uipath_langchain.agent.tools.context_tool.UiPath", @@ -846,7 +846,7 @@ async def test_batch_transform_uses_execution_folder_path( mock_uipath.jobs.create_attachment_async = AsyncMock(return_value="att-id") with ( patch( - "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" + "uipath_langchain._utils.durable_interrupt.decorator.interrupt" ) as mock_interrupt, patch( "uipath_langchain.agent.tools.context_tool.UiPath", diff --git a/tests/agent/tools/test_durable_interrupt.py b/tests/agent/tools/test_durable_interrupt.py index 200c419d9..6f6825561 100644 --- a/tests/agent/tools/test_durable_interrupt.py +++ b/tests/agent/tools/test_durable_interrupt.py @@ -7,7 +7,7 @@ import pytest from langgraph._internal._constants import CONFIG_KEY_SCRATCHPAD -from uipath_langchain.agent.tools.durable_interrupt import ( +from uipath_langchain._utils.durable_interrupt import ( _durable_state, durable_interrupt, ) @@ -27,8 +27,8 @@ def _make_config(scratchpad: FakeScratchpad | None = None) -> dict[str, Any]: return {"configurable": {CONFIG_KEY_SCRATCHPAD: scratchpad}} -PATCH_GET_CONFIG = "uipath_langchain.agent.tools.durable_interrupt.decorator.get_config" -PATCH_INTERRUPT = "uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt" +PATCH_GET_CONFIG = "uipath_langchain._utils.durable_interrupt.decorator.get_config" +PATCH_INTERRUPT = "uipath_langchain._utils.durable_interrupt.decorator.interrupt" @pytest.fixture(autouse=True) diff --git a/tests/agent/tools/test_escalation_tool.py b/tests/agent/tools/test_escalation_tool.py index 67c843ecc..8106b56f1 100644 --- a/tests/agent/tools/test_escalation_tool.py +++ b/tests/agent/tools/test_escalation_tool.py @@ -288,7 +288,7 @@ async def test_escalation_tool_metadata_has_channel_type(self, escalation_resour @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_metadata_has_recipient( self, mock_interrupt, mock_uipath_class, escalation_resource ): @@ -317,7 +317,7 @@ async def test_escalation_tool_metadata_has_recipient( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_metadata_recipient_none_when_no_recipients( self, mock_interrupt, mock_uipath_class, escalation_resource_no_recipient ): @@ -342,7 +342,7 @@ async def test_escalation_tool_metadata_recipient_none_when_no_recipients( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_with_string_task_title( self, mock_interrupt, mock_uipath_class ): @@ -392,7 +392,7 @@ async def test_escalation_tool_with_string_task_title( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_with_text_builder_task_title( self, mock_interrupt, mock_uipath_class ): @@ -450,7 +450,7 @@ async def test_escalation_tool_with_text_builder_task_title( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_with_empty_task_title_defaults_to_escalation_task( self, mock_interrupt, mock_uipath_class ): @@ -548,7 +548,7 @@ async def test_escalation_tool_output_schema_has_action_field( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_result_validation( self, mock_interrupt, mock_uipath_class, escalation_resource ): @@ -576,7 +576,7 @@ async def test_escalation_tool_result_validation( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_extracts_action_from_result( self, mock_interrupt, mock_uipath_class, escalation_resource ): @@ -600,7 +600,7 @@ async def test_escalation_tool_extracts_action_from_result( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_raises_when_task_is_deleted( self, mock_interrupt, mock_uipath_class, escalation_resource ): @@ -625,7 +625,7 @@ async def test_escalation_tool_raises_when_task_is_deleted( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_dict_result_without_is_deleted_defaults_to_false( self, mock_interrupt, mock_uipath_class, escalation_resource ): @@ -650,7 +650,7 @@ async def test_escalation_tool_dict_result_without_is_deleted_defaults_to_false( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_escalation_tool_with_outcome_mapping_end( self, mock_interrupt, mock_uipath_class ): @@ -830,7 +830,7 @@ def escalation_resource(self): @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_creates_task_then_interrupts_with_wait_escalation( self, mock_interrupt, mock_uipath_class, escalation_resource ): @@ -865,7 +865,7 @@ async def test_creates_task_then_interrupts_with_wait_escalation( @pytest.mark.asyncio @patch.dict(os.environ, {"UIPATH_FOLDER_PATH": "/Test/Folder"}) @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_creates_task_with_execution_folder_path( self, mock_interrupt, mock_uipath_class, escalation_resource ): diff --git a/tests/agent/tools/test_ixp_escalation_tool.py b/tests/agent/tools/test_ixp_escalation_tool.py index 5e7dab19a..723f4aad4 100644 --- a/tests/agent/tools/test_ixp_escalation_tool.py +++ b/tests/agent/tools/test_ixp_escalation_tool.py @@ -187,7 +187,7 @@ def mock_state_without_extraction(self): @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.ixp_escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_wrapper_retrieves_extraction_from_state( self, mock_interrupt, @@ -287,7 +287,7 @@ async def test_wrapper_looks_for_correct_ixp_tool_id( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.ixp_escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_wrapper_raises_on_document_rejection( self, mock_interrupt, @@ -368,7 +368,7 @@ def mock_extraction_response(self): @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.ixp_escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_tool_calls_interrupt_with_correct_params( self, mock_interrupt, @@ -409,7 +409,7 @@ async def test_tool_calls_interrupt_with_correct_params( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.ixp_escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_tool_uses_default_action_title_when_not_provided( self, mock_interrupt, mock_uipath_cls, mock_extraction_response ): @@ -462,7 +462,7 @@ async def test_tool_uses_default_action_title_when_not_provided( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.ixp_escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_tool_uses_default_priority_when_not_provided( self, mock_interrupt, mock_uipath_cls, mock_extraction_response ): @@ -515,7 +515,7 @@ async def test_tool_uses_default_priority_when_not_provided( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.ixp_escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_tool_returns_data_projection_as_dict( self, mock_interrupt, @@ -543,7 +543,7 @@ async def test_tool_returns_data_projection_as_dict( @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.ixp_escalation_tool.UiPath") - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") async def test_tool_stores_validation_response_in_metadata( self, mock_interrupt, diff --git a/tests/agent/tools/test_langgraph_interrupt_contract.py b/tests/agent/tools/test_langgraph_interrupt_contract.py index 0231d818a..792309d9e 100644 --- a/tests/agent/tools/test_langgraph_interrupt_contract.py +++ b/tests/agent/tools/test_langgraph_interrupt_contract.py @@ -170,7 +170,7 @@ class TestDurableInterruptAlignment: def test_single_durable_interrupt_returns_resume_value(self) -> None: """On resume, durable_interrupt returns the resume value directly.""" - from uipath_langchain.agent.tools.durable_interrupt import ( + from uipath_langchain._utils.durable_interrupt import ( _durable_state, durable_interrupt, ) @@ -196,7 +196,7 @@ def create_job() -> dict[str, str]: def test_two_durable_interrupts_return_sequential_resume_values(self) -> None: """Two @durable_interrupt calls return resume values by index.""" - from uipath_langchain.agent.tools.durable_interrupt import ( + from uipath_langchain._utils.durable_interrupt import ( _durable_state, durable_interrupt, ) @@ -224,7 +224,7 @@ def task_b() -> str: def test_partial_resume_first_returns_value_second_raises(self) -> None: """One resume value: first returns it, second runs body and raises GraphInterrupt.""" - from uipath_langchain.agent.tools.durable_interrupt import ( + from uipath_langchain._utils.durable_interrupt import ( _durable_state, durable_interrupt, ) @@ -256,7 +256,7 @@ def task_b() -> str: async def test_async_durable_interrupt_returns_resume_value(self) -> None: """Async variant: durable_interrupt returns resume value directly.""" - from uipath_langchain.agent.tools.durable_interrupt import ( + from uipath_langchain._utils.durable_interrupt import ( _durable_state, durable_interrupt, ) diff --git a/tests/agent/tools/test_process_tool.py b/tests/agent/tools/test_process_tool.py index f701e2500..3f1f7642c 100644 --- a/tests/agent/tools/test_process_tool.py +++ b/tests/agent/tools/test_process_tool.py @@ -117,7 +117,7 @@ class TestProcessToolInvocation: @pytest.mark.asyncio @patch.dict(os.environ, {"UIPATH_FOLDER_PATH": "/Shared/MyFolder"}) - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_invoke_calls_processes_invoke_async( self, mock_uipath_class, mock_interrupt, process_resource @@ -150,7 +150,7 @@ async def test_invoke_calls_processes_invoke_async( ) @pytest.mark.asyncio - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_invoke_interrupts_with_wait_job( self, mock_uipath_class, mock_interrupt, process_resource @@ -181,7 +181,7 @@ async def test_invoke_interrupts_with_wait_job( @pytest.mark.asyncio @patch.dict(os.environ, {"UIPATH_FOLDER_PATH": "/Shared/DataFolder"}) - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_invoke_passes_input_arguments( self, mock_uipath_class, mock_interrupt, process_resource_with_inputs @@ -210,7 +210,7 @@ async def test_invoke_passes_input_arguments( assert call_kwargs["folder_path"] == "/Shared/DataFolder" @pytest.mark.asyncio - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_invoke_returns_output_from_extract( self, mock_uipath_class, mock_interrupt, process_resource @@ -238,7 +238,7 @@ async def test_invoke_returns_output_from_extract( assert result == {"output_arg": "value123"} @pytest.mark.asyncio - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_invoke_returns_error_message_on_faulted_job( self, mock_uipath_class, mock_interrupt, process_resource @@ -270,7 +270,7 @@ class TestProcessToolSpanContext: """Test that _span_context is properly wired for tracing.""" @pytest.mark.asyncio - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_span_context_parent_span_id_passed_to_invoke( self, mock_uipath_class, mock_interrupt, process_resource @@ -302,7 +302,7 @@ async def test_span_context_parent_span_id_passed_to_invoke( assert call_kwargs["parent_span_id"] == "span-abc-123" @pytest.mark.asyncio - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_span_context_consumed_after_invoke( self, mock_uipath_class, mock_interrupt, process_resource @@ -332,7 +332,7 @@ async def test_span_context_consumed_after_invoke( assert "parent_span_id" not in tool.metadata["_span_context"] @pytest.mark.asyncio - @patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @patch("uipath_langchain.agent.tools.process_tool.UiPath") async def test_span_context_defaults_to_none_when_empty( self, mock_uipath_class, mock_interrupt, process_resource diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index af3da38cb..870cedf18 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -1,6 +1,8 @@ """Tests for tool_node.py module.""" +import json from typing import Any, Dict +from unittest.mock import patch import pytest from langchain_core.messages import AIMessage, HumanMessage @@ -13,11 +15,16 @@ AgentRuntimeError, AgentRuntimeErrorCode, ) +from uipath_langchain.agent.react.types import AgentGraphState from uipath_langchain.agent.tools.tool_node import ( ToolWrapperMixin, UiPathToolNode, create_tool_node, ) +from uipath_langchain.chat.hitl import ( + CANCELLED_MESSAGE, + CONVERSATIONAL_APPROVED_TOOL_ARGS, +) class MockTool(BaseTool): @@ -66,10 +73,9 @@ class FilteredState(BaseModel): session_id: str = "test_session" -class MockState(BaseModel): +class MockState(AgentGraphState): """Mock state for testing.""" - messages: list[Any] = [] user_id: str = "test_user" session_id: str = "test_session" @@ -310,7 +316,7 @@ def test_tool_error_propagates_when_handle_errors_false(self, mock_state): node = UiPathToolNode(failing_tool, handle_tool_errors=False) with pytest.raises(ValueError) as exc_info: - node._func(state) # type: ignore[arg-type] + node._func(state) assert "Tool execution failed: test input" in str(exc_info.value) @@ -328,7 +334,7 @@ async def test_async_tool_error_propagates_when_handle_errors_false(self): node = UiPathToolNode(failing_tool, handle_tool_errors=False) with pytest.raises(ValueError) as exc_info: - await node._afunc(state) # type: ignore[arg-type] + await node._afunc(state) assert "Async tool execution failed: test input" in str(exc_info.value) @@ -345,7 +351,7 @@ def test_tool_error_captured_when_handle_errors_true(self): node = UiPathToolNode(failing_tool, handle_tool_errors=True) - result = node._func(state) # type: ignore[arg-type] + result = node._func(state) assert result is not None assert isinstance(result, dict) @@ -372,7 +378,7 @@ async def test_async_tool_error_captured_when_handle_errors_true(self): node = UiPathToolNode(failing_tool, handle_tool_errors=True) - result = await node._afunc(state) # type: ignore[arg-type] + result = await node._afunc(state) assert result is not None assert isinstance(result, dict) @@ -482,3 +488,185 @@ def test_create_tool_node_with_handle_errors_true(self): node = result[tool_name] assert isinstance(node, UiPathToolNode) assert node.handle_tool_errors is True + + +class TestToolNodeConfirmation: + """Tests for confirmation flow in UiPathToolNode._func / _afunc.""" + + @pytest.fixture + def confirmation_tool(self): + """Tool with require_conversational_confirmation metadata.""" + return MockTool(metadata={"require_conversational_confirmation": True}) + + @pytest.fixture + def confirmation_state(self): + tool_call = { + "name": "mock_tool", + "args": {"input_text": "test input"}, + "id": "test_call_id", + } + ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) + return MockState(messages=[ai_message]) + + def test_no_confirmation_without_metadata(self): + """Tool without metadata executes normally, no interrupt.""" + tool = MockTool() # no metadata + node = UiPathToolNode(tool) + tool_call = { + "name": "mock_tool", + "args": {"input_text": "hello"}, + "id": "call_1", + } + state = MockState(messages=[AIMessage(content="go", tool_calls=[tool_call])]) + + result = node._func(state) + + assert result is not None + assert isinstance(result, dict) + assert "Mock result: hello" in result["messages"][0].content + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + def test_cancelled_returns_cancelled_message( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Rejected confirmation returns CANCELLED_MESSAGE.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert isinstance(msg, ToolMessage) + assert msg.content == json.dumps({"meta": CANCELLED_MESSAGE}) + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "test input"}, + ) + def test_approved_same_args_no_meta( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Approved with same args → normal execution, no meta injected.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert "args_modified_by_user" not in msg.content + assert "Mock result:" in msg.content + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "edited"}, + ) + def test_approved_modified_args_injects_meta( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Approved with edited args → tool runs with new args, meta injected.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + + assert isinstance(msg.content, str) + wrapped = json.loads(msg.content) + assert wrapped["meta"]["args_modified_by_user"] is True + assert wrapped["meta"]["executed_args"] == {"input_text": "edited"} + assert "Mock result: edited" in wrapped["result"] + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + async def test_async_cancelled( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Async path: rejected confirmation returns CANCELLED_MESSAGE.""" + node = UiPathToolNode(confirmation_tool) + + result = await node._afunc(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert msg.content == json.dumps({"meta": CANCELLED_MESSAGE}) + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "async edited"}, + ) + async def test_async_approved_modified_args( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Async path: approved with edited args → meta injected.""" + node = UiPathToolNode(confirmation_tool) + + result = await node._afunc(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + + assert isinstance(msg.content, str) + wrapped = json.loads(msg.content) + assert wrapped["meta"]["args_modified_by_user"] is True + assert wrapped["meta"]["executed_args"] == {"input_text": "async edited"} + assert "Async mock result: async edited" in wrapped["result"] + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "approved"}, + ) + def test_approved_attaches_approved_args_metadata( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Approved path attaches approved args in response_metadata.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert CONVERSATIONAL_APPROVED_TOOL_ARGS in msg.response_metadata + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "input_text": "approved" + } + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + def test_cancelled_attaches_original_args_metadata( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Cancelled path attaches original args in response_metadata.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert CONVERSATIONAL_APPROVED_TOOL_ARGS in msg.response_metadata + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "input_text": "test input" + } + + def test_no_confirmation_no_metadata(self): + """Non-confirmation tools don't get the approved args metadata.""" + tool = MockTool() # no confirmation metadata + node = UiPathToolNode(tool) + tool_call = { + "name": "mock_tool", + "args": {"input_text": "hello"}, + "id": "call_1", + } + state = MockState(messages=[AIMessage(content="go", tool_calls=[tool_call])]) + + result = node._func(state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert CONVERSATIONAL_APPROVED_TOOL_ARGS not in msg.response_metadata diff --git a/tests/chat/test_hitl.py b/tests/chat/test_hitl.py new file mode 100644 index 000000000..5ef910324 --- /dev/null +++ b/tests/chat/test_hitl.py @@ -0,0 +1,187 @@ +"""Tests for hitl.py module.""" + +import json +from typing import Any +from unittest.mock import patch + +from langchain_core.messages.tool import ToolCall, ToolMessage +from langchain_core.tools import BaseTool + +from uipath_langchain.chat.hitl import ( + CANCELLED_MESSAGE, + CONVERSATIONAL_APPROVED_TOOL_ARGS, + ConfirmationResult, + request_approval, + request_conversational_tool_confirmation, +) + + +class MockTool(BaseTool): + name: str = "mock_tool" + description: str = "A mock tool" + + def _run(self) -> str: + return "" + + +def _make_call(args: dict[str, Any] | None = None) -> ToolCall: + return ToolCall(name="mock_tool", args=args or {"query": "test"}, id="call_1") + + +class TestCheckToolConfirmation: + """Tests for request_conversational_tool_confirmation.""" + + def test_returns_none_when_no_metadata(self): + """No metadata → no confirmation needed.""" + tool = MockTool() + call = _make_call() + assert request_conversational_tool_confirmation(call, tool) is None + + def test_returns_none_when_flag_not_set(self): + """Metadata exists but flag is missing → no confirmation needed.""" + tool = MockTool(metadata={"other_key": True}) + call = _make_call() + assert request_conversational_tool_confirmation(call, tool) is None + + def test_returns_none_when_flag_false(self): + """Flag explicitly False → no confirmation needed.""" + tool = MockTool(metadata={"require_conversational_confirmation": False}) + call = _make_call() + assert request_conversational_tool_confirmation(call, tool) is None + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + def test_cancelled_returns_tool_message(self, mock_approval): + """User rejects → ConfirmationResult with cancelled ToolMessage and metadata.""" + tool = MockTool(metadata={"require_conversational_confirmation": True}) + call = _make_call() + + result = request_conversational_tool_confirmation(call, tool) + + assert result is not None + assert isinstance(result, ConfirmationResult) + assert result.cancelled is not None + assert isinstance(result.cancelled, ToolMessage) + assert result.cancelled.content == json.dumps({"meta": CANCELLED_MESSAGE}) + assert result.cancelled.name == "mock_tool" + assert result.cancelled.tool_call_id == "call_1" + assert result.args_modified is False + assert result.cancelled.response_metadata[ + CONVERSATIONAL_APPROVED_TOOL_ARGS + ] == {"query": "test"} + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"query": "test"}, + ) + def test_approved_same_args(self, mock_approval): + """User approves without editing → cancelled=None, args_modified=False.""" + tool = MockTool(metadata={"require_conversational_confirmation": True}) + call = _make_call({"query": "test"}) + + result = request_conversational_tool_confirmation(call, tool) + + assert result is not None + assert result.cancelled is None + assert result.args_modified is False + assert result.approved_args == {"query": "test"} + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"query": "edited"}, + ) + def test_approved_modified_args(self, mock_approval): + """User edits args → cancelled=None, args_modified=True, call updated.""" + tool = MockTool(metadata={"require_conversational_confirmation": True}) + call = _make_call({"query": "original"}) + + result = request_conversational_tool_confirmation(call, tool) + + assert result is not None + assert result.cancelled is None + assert result.args_modified is True + assert result.approved_args == {"query": "edited"} + assert call["args"] == {"query": "edited"} + + +class TestAnnotateResult: + """Tests for ConfirmationResult.annotate_result.""" + + def test_annotate_sets_metadata(self): + """annotate_result sets approved_args on response_metadata.""" + confirmation = ConfirmationResult( + cancelled=None, args_modified=False, approved_args={"query": "test"} + ) + msg = ToolMessage(content="result", tool_call_id="call_1") + output = {"messages": [msg]} + + confirmation.annotate_result(output) + + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "query": "test" + } + assert msg.content == "result" + + def test_annotate_wraps_content_when_modified(self): + """annotate_result wraps content with structured meta when args were modified.""" + confirmation = ConfirmationResult( + cancelled=None, args_modified=True, approved_args={"query": "edited"} + ) + msg = ToolMessage(content="result", tool_call_id="call_1") + output = {"messages": [msg]} + + confirmation.annotate_result(output) + + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "query": "edited" + } + import json + + assert isinstance(msg.content, str) + wrapped = json.loads(msg.content) + assert wrapped["meta"]["args_modified_by_user"] is True + assert wrapped["meta"]["executed_args"] == {"query": "edited"} + assert wrapped["result"] == "result" + + +class TestRequestApprovalTruthiness: + """Tests for the truthiness fix in request_approval.""" + + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") + def test_empty_dict_input_preserved(self, mock_interrupt): + """Empty dict from user edits should not be replaced by original args.""" + mock_interrupt.return_value = {"value": {"approved": True, "input": {}}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == {} + + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") + def test_empty_list_input_preserved(self, mock_interrupt): + """Empty list from user edits should not be replaced by original args.""" + mock_interrupt.return_value = {"value": {"approved": True, "input": []}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == [] + + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") + def test_none_input_falls_back_to_original(self, mock_interrupt): + """None input should fall back to original tool_args.""" + mock_interrupt.return_value = {"value": {"approved": True, "input": None}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == {"query": "test"} + + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") + def test_missing_input_falls_back_to_original(self, mock_interrupt): + """Missing input key should fall back to original tool_args.""" + mock_interrupt.return_value = {"value": {"approved": True}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == {"query": "test"} + + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") + def test_rejected_returns_none(self, mock_interrupt): + """Rejected approval returns None.""" + mock_interrupt.return_value = {"value": {"approved": False}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result is None diff --git a/tests/runtime/test_chat_message_mapper.py b/tests/runtime/test_chat_message_mapper.py index 3eabe5e66..35db6a912 100644 --- a/tests/runtime/test_chat_message_mapper.py +++ b/tests/runtime/test_chat_message_mapper.py @@ -1718,3 +1718,129 @@ def test_ai_message_with_media_citation(self): assert isinstance(source, UiPathConversationCitationSourceMedia) assert source.download_url == "https://r.com" assert source.page_number == "3" + + +class TestConfirmationToolDeferral: + """Tests for deferring startToolCall events for confirmation tools.""" + + @pytest.mark.asyncio + async def test_start_tool_call_skipped_for_confirmation_tool(self): + """AIMessageChunk with confirmation tool should NOT emit startToolCall.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"confirm_tool"} + + # First chunk starts the message with a confirmation tool call + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[{"id": "tc-1", "name": "confirm_tool", "args": {"x": 1}}], + ) + await mapper.map_event(first_chunk) + + # Last chunk triggers tool call start events + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + tool_start_events = [ + e + for e in result + if e.tool_call is not None and e.tool_call.start is not None + ] + assert len(tool_start_events) == 0 + + @pytest.mark.asyncio + async def test_start_tool_call_emitted_for_non_confirmation_tool(self): + """Normal tools still emit startToolCall even when confirmation set is populated.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"other_tool"} + + first_chunk = AIMessageChunk( + content="", + id="msg-2", + tool_calls=[{"id": "tc-2", "name": "normal_tool", "args": {}}], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-2") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + tool_start_events = [ + e + for e in result + if e.tool_call is not None and e.tool_call.start is not None + ] + assert len(tool_start_events) >= 1 + assert tool_start_events[0].tool_call is not None + assert tool_start_events[0].tool_call.start is not None + assert tool_start_events[0].tool_call.start.tool_name == "normal_tool" + + @pytest.mark.asyncio + async def test_confirmation_tool_message_emits_only_end(self): + """ToolMessage for a confirmation tool should only emit endToolCall + messageEnd. + + startToolCall is now emitted by the bridge on HITL approval, not here. + """ + storage = create_mock_storage() + storage.get_value.return_value = {"tc-3": "msg-3"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"confirm_tool"} + + tool_msg = ToolMessage( + content='{"result": "ok"}', + tool_call_id="tc-3", + name="confirm_tool", + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + # Should have: endToolCall, messageEnd (no startToolCall) + assert len(result) == 2 + + # First event: endToolCall + end_event = result[0] + assert end_event.tool_call is not None + assert end_event.tool_call.end is not None + + # Second event: messageEnd + assert result[1].end is not None + + @pytest.mark.asyncio + async def test_mixed_tools_only_confirmation_deferred(self): + """Mixed tools in one AIMessage: only confirmation tool's startToolCall is deferred.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"confirm_tool"} + + first_chunk = AIMessageChunk( + content="", + id="msg-4", + tool_calls=[ + {"id": "tc-normal", "name": "normal_tool", "args": {"a": 1}}, + {"id": "tc-confirm", "name": "confirm_tool", "args": {"b": 2}}, + ], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-4") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + tool_start_names = [ + e.tool_call.start.tool_name + for e in result + if e.tool_call is not None and e.tool_call.start is not None + ] + # normal_tool should have startToolCall, confirm_tool should NOT + assert "normal_tool" in tool_start_names + assert "confirm_tool" not in tool_start_names