Skip to content

Commit eeeaf00

Browse files
committed
LangGraph: Unify plugin for Graph API and Functional API
- Create unified LangGraphPlugin supporting both StateGraph and @entrypoint - Add auto-detection to compile() for returning correct runner type - Fix _is_entrypoint() to distinguish @entrypoint (Pregel) from StateGraph.compile() (CompiledStateGraph) by checking class type and presence of __start__ node - Fix timedelta serialization in _filter_config() by excluding temporal options from metadata (handled separately by _get_node_activity_options) - Update tests to use real graphs instead of MagicMock
1 parent 02c4ca7 commit eeeaf00

10 files changed

Lines changed: 618 additions & 177 deletions

File tree

temporalio/contrib/langgraph/__init__.py

Lines changed: 175 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
from datetime import timedelta
11-
from typing import Any
11+
from typing import TYPE_CHECKING, Any, Union
1212

1313
import temporalio.common
1414
import temporalio.workflow
@@ -19,24 +19,37 @@
1919
GraphAlreadyRegisteredError,
2020
)
2121
from temporalio.contrib.langgraph._functional_activity import execute_langgraph_task
22-
from temporalio.contrib.langgraph._functional_plugin import LangGraphFunctionalPlugin
22+
23+
# Backward compatibility - LangGraphFunctionalPlugin is now deprecated
24+
# Use LangGraphPlugin with entrypoints in the graphs parameter instead
25+
from temporalio.contrib.langgraph._functional_plugin import (
26+
LangGraphFunctionalPlugin,
27+
)
2328
from temporalio.contrib.langgraph._functional_registry import (
2429
get_entrypoint,
2530
register_entrypoint,
2631
)
32+
from temporalio.contrib.langgraph._functional_registry import (
33+
get_global_entrypoint_registry as _get_functional_registry,
34+
)
2735
from temporalio.contrib.langgraph._functional_runner import (
2836
TemporalFunctionalRunner,
29-
compile_functional,
3037
)
3138
from temporalio.contrib.langgraph._graph_registry import (
3239
get_default_activity_options,
3340
get_graph,
3441
get_per_node_activity_options,
3542
)
43+
from temporalio.contrib.langgraph._graph_registry import (
44+
get_global_registry as _get_graph_registry,
45+
)
3646
from temporalio.contrib.langgraph._models import StateSnapshot
3747
from temporalio.contrib.langgraph._plugin import LangGraphPlugin
3848
from temporalio.contrib.langgraph._runner import TemporalLangGraphRunner
3949

50+
if TYPE_CHECKING:
51+
from temporalio.contrib.langgraph._plugin import ActivityOptionsKey
52+
4053

4154
def activity_options(
4255
*,
@@ -53,14 +66,7 @@ def activity_options(
5366
) -> dict[str, Any]:
5467
"""Create activity options for LangGraph integration.
5568
56-
Use with Graph API:
57-
- ``graph.add_node(metadata=activity_options(...))`` for node activities
58-
- ``LangGraphPlugin(per_node_activity_options={"node": activity_options(...)})``
59-
60-
Use with Functional API:
61-
- ``compile_functional(task_options={"task_name": activity_options(...)})``
62-
- ``LangGraphFunctionalPlugin(task_options={"task": activity_options(...)})``
63-
69+
Use with plugin registration or compile() for workflow-level overrides.
6470
Parameters mirror ``workflow.execute_activity()``.
6571
"""
6672
config: dict[str, Any] = {}
@@ -95,7 +101,7 @@ def temporal_node_metadata(
95101
"""Create node metadata combining activity options and execution flags.
96102
97103
Args:
98-
activity_options: Options from ``node_activity_options()``.
104+
activity_options: Options from ``activity_options()``.
99105
run_in_workflow: If True, run in workflow instead of as activity.
100106
"""
101107
# Start with activity options if provided, otherwise empty temporal config
@@ -118,23 +124,75 @@ def compile(
118124
graph_id: str,
119125
*,
120126
default_activity_options: dict[str, Any] | None = None,
121-
per_node_activity_options: dict[str, dict[str, Any]] | None = None,
127+
activity_options: dict[str, dict[str, Any]] | None = None,
122128
checkpoint: dict | None = None,
123-
) -> TemporalLangGraphRunner:
124-
"""Compile a registered graph for Temporal execution.
129+
) -> Union[TemporalLangGraphRunner, TemporalFunctionalRunner]:
130+
"""Compile a registered graph or entrypoint for Temporal execution.
131+
132+
This function auto-detects whether the ID refers to a Graph API graph
133+
(StateGraph) or a Functional API entrypoint (@entrypoint/@task).
125134
126135
.. warning::
127136
This API is experimental and may change in future versions.
128137
129138
Args:
130-
graph_id: ID of graph registered with LangGraphPlugin.
131-
default_activity_options: Default options for all nodes.
132-
per_node_activity_options: Per-node options by node name.
139+
graph_id: ID of graph or entrypoint registered with LangGraphPlugin.
140+
default_activity_options: Default options for all nodes/tasks.
141+
Use activity_options() helper to create.
142+
activity_options: Per-node/task options by name.
143+
Use activity_options() helper to create values.
133144
checkpoint: Checkpoint from previous get_state() for continue-as-new.
145+
Only applies to Graph API graphs.
146+
147+
Returns:
148+
TemporalLangGraphRunner for Graph API graphs, or
149+
TemporalFunctionalRunner for Functional API entrypoints.
134150
135151
Raises:
136-
ApplicationError: If no graph with the given ID is registered.
152+
ApplicationError: If no graph or entrypoint with the given ID is registered.
137153
"""
154+
# Check which registry has this ID
155+
graph_registry = _get_graph_registry()
156+
functional_registry = _get_functional_registry()
157+
158+
is_graph = graph_registry.is_registered(graph_id)
159+
is_entrypoint = functional_registry.is_registered(graph_id)
160+
161+
if is_graph:
162+
return _compile_graph(
163+
graph_id,
164+
default_activity_options=default_activity_options,
165+
per_node_activity_options=activity_options,
166+
checkpoint=checkpoint,
167+
)
168+
elif is_entrypoint:
169+
return _compile_entrypoint(
170+
graph_id,
171+
default_activity_options=default_activity_options,
172+
task_options=activity_options,
173+
)
174+
else:
175+
# Neither registry has it - raise error
176+
from temporalio.exceptions import ApplicationError
177+
178+
graph_ids = graph_registry.list_graphs()
179+
entrypoint_ids = functional_registry.list_entrypoints()
180+
all_ids = graph_ids + entrypoint_ids
181+
raise ApplicationError(
182+
f"'{graph_id}' not found. Available: {all_ids}",
183+
type=GRAPH_NOT_FOUND_ERROR,
184+
non_retryable=True,
185+
)
186+
187+
188+
def _compile_graph(
189+
graph_id: str,
190+
*,
191+
default_activity_options: dict[str, Any] | None = None,
192+
per_node_activity_options: dict[str, dict[str, Any]] | None = None,
193+
checkpoint: dict | None = None,
194+
) -> TemporalLangGraphRunner:
195+
"""Compile a Graph API graph for Temporal execution."""
138196
# Get graph from registry
139197
pregel = get_graph(graph_id)
140198

@@ -145,11 +203,7 @@ def compile(
145203
def _merge_activity_options(
146204
base: dict[str, Any], override: dict[str, Any]
147205
) -> dict[str, Any]:
148-
"""Merge activity options, with override taking precedence.
149-
150-
Both dicts have structure {"temporal": {...}} from node_activity_options().
151-
We need to merge the inner "temporal" dicts.
152-
"""
206+
"""Merge activity options, with override taking precedence."""
153207
base_temporal = base.get("temporal", {})
154208
override_temporal = override.get("temporal", {})
155209
return {"temporal": {**base_temporal, **override_temporal}}
@@ -186,21 +240,115 @@ def _merge_activity_options(
186240
)
187241

188242

243+
def _compile_entrypoint(
244+
entrypoint_id: str,
245+
*,
246+
default_activity_options: dict[str, Any] | None = None,
247+
task_options: dict[str, dict[str, Any]] | None = None,
248+
) -> TemporalFunctionalRunner:
249+
"""Compile a Functional API entrypoint for Temporal execution."""
250+
from temporalio.contrib.langgraph._functional_registry import (
251+
get_entrypoint_default_options,
252+
get_entrypoint_task_options,
253+
)
254+
255+
# Get plugin-level options from registry
256+
plugin_default_options = get_entrypoint_default_options(entrypoint_id)
257+
plugin_task_options = get_entrypoint_task_options(entrypoint_id)
258+
259+
# Merge default options
260+
merged_default_options: dict[str, Any] | None = None
261+
if plugin_default_options or default_activity_options:
262+
# Unwrap activity_options format if needed
263+
base = plugin_default_options or {}
264+
override = default_activity_options or {}
265+
if "temporal" in base:
266+
base = base.get("temporal", {})
267+
if "temporal" in override:
268+
override = override.get("temporal", {})
269+
merged_default_options = {**base, **override}
270+
271+
# Merge per-task options
272+
merged_task_options: dict[str, dict[str, Any]] | None = None
273+
if plugin_task_options or task_options:
274+
merged_task_options = {}
275+
# Start with plugin options
276+
for task_name, opts in (plugin_task_options or {}).items():
277+
merged_task_options[task_name] = opts
278+
# Merge compile options
279+
if task_options:
280+
for task_name, opts in task_options.items():
281+
if task_name in merged_task_options:
282+
# Merge the options
283+
base = merged_task_options[task_name]
284+
if "temporal" in base:
285+
base = base.get("temporal", {})
286+
override = opts
287+
if "temporal" in override:
288+
override = override.get("temporal", {})
289+
merged_task_options[task_name] = {**base, **override}
290+
else:
291+
# Unwrap if needed
292+
if "temporal" in opts:
293+
merged_task_options[task_name] = opts.get("temporal", {})
294+
else:
295+
merged_task_options[task_name] = opts
296+
297+
# Get default timeout from merged options
298+
default_timeout = timedelta(minutes=5)
299+
if merged_default_options:
300+
if "start_to_close_timeout" in merged_default_options:
301+
default_timeout = merged_default_options["start_to_close_timeout"]
302+
303+
return TemporalFunctionalRunner(
304+
entrypoint_id=entrypoint_id,
305+
default_task_timeout=default_timeout,
306+
task_options=merged_task_options,
307+
)
308+
309+
310+
# Keep compile_functional for backward compatibility (deprecated)
311+
def compile_functional(
312+
entrypoint_id: str,
313+
default_task_timeout: timedelta = timedelta(minutes=5),
314+
task_options: dict[str, dict[str, Any]] | None = None,
315+
) -> TemporalFunctionalRunner:
316+
"""Compile a registered entrypoint for Temporal execution.
317+
318+
.. deprecated::
319+
Use ``compile()`` instead, which auto-detects graph vs entrypoint.
320+
321+
Args:
322+
entrypoint_id: ID of the registered entrypoint.
323+
default_task_timeout: Default timeout for task activities.
324+
task_options: Per-task activity options.
325+
326+
Returns:
327+
A TemporalFunctionalRunner that can be used to invoke the entrypoint.
328+
"""
329+
return TemporalFunctionalRunner(
330+
entrypoint_id=entrypoint_id,
331+
default_task_timeout=default_task_timeout,
332+
task_options=task_options,
333+
)
334+
335+
189336
__all__ = [
190-
# Main API - Graph API
337+
# Main unified API
191338
"activity_options",
192339
"compile",
193340
"LangGraphPlugin",
194341
"StateSnapshot",
195342
"temporal_node_metadata",
343+
# Runner types (for type annotations)
196344
"TemporalLangGraphRunner",
197-
# Main API - Functional API
345+
"TemporalFunctionalRunner",
346+
# Deprecated (kept for backward compatibility)
198347
"compile_functional",
199348
"execute_langgraph_task",
200349
"get_entrypoint",
201350
"LangGraphFunctionalPlugin",
202351
"register_entrypoint",
203-
"TemporalFunctionalRunner",
204352
# Exception types (for catching configuration errors)
205353
"GraphAlreadyRegisteredError",
206354
# Error type constants (for catching ApplicationError.type)

temporalio/contrib/langgraph/_functional_plugin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ class LangGraphFunctionalPlugin(SimplePlugin):
4343
Registers @entrypoint functions, auto-registers the dynamic task activity,
4444
and configures the Pydantic data converter for LangChain messages.
4545
46-
Example:
47-
```python
46+
Example::
47+
4848
from langgraph.func import entrypoint, task
4949
5050
@task
@@ -58,7 +58,6 @@ async def my_entrypoint(x: int) -> int:
5858
plugin = LangGraphFunctionalPlugin(
5959
entrypoints={"my_entrypoint": my_entrypoint},
6060
)
61-
```
6261
"""
6362

6463
def __init__(

temporalio/contrib/langgraph/_functional_runner.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -562,16 +562,6 @@ def compile_functional(
562562
563563
Returns:
564564
A TemporalFunctionalRunner that can be used to invoke the entrypoint.
565-
566-
Example:
567-
```python
568-
@workflow.defn
569-
class MyWorkflow:
570-
@workflow.run
571-
async def run(self, input: str) -> dict:
572-
app = compile_functional("my_entrypoint")
573-
return await app.ainvoke(input)
574-
```
575565
"""
576566
return TemporalFunctionalRunner(
577567
entrypoint_id=entrypoint_id,

temporalio/contrib/langgraph/_models.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,7 @@ def _coerce_state_values(state: dict[str, Any]) -> dict[str, Any]:
6464
because when state passes through Temporal serialization, LangChain message
6565
objects become plain dicts.
6666
67-
Handles nested structures like tool_call_with_context:
68-
{
69-
"__type": "tool_call_with_context",
70-
"tool_call": {...},
71-
"state": {"messages": [...]} # nested messages are coerced
72-
}
67+
Handles nested structures like tool_call_with_context with nested messages.
7368
"""
7469
return {key: _coerce_value(value) for key, value in state.items()}
7570

0 commit comments

Comments
 (0)