From 666b01ebe76f1fa82215004220d90ec150d759a5 Mon Sep 17 00:00:00 2001 From: Aditya Mehra Date: Wed, 11 Feb 2026 09:10:51 -0800 Subject: [PATCH] fix(genai-util): propagate workflow context to downstream spans and metrics Add workflow context inheritance in TelemetryHandler and include gen_ai.workflow.name in downstream LLM/tool/embedding/retrieval metric attributes so nested operations consistently carry workflow identity. Co-authored-by: Cursor --- .../util/genai/emitters/metrics.py | 50 ++++++++ .../src/opentelemetry/util/genai/handler.py | 110 +++++++++++++++++ .../tests/test_metrics.py | 116 +++++++++++++++++- 3 files changed, 275 insertions(+), 1 deletion(-) diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py index 3f385633..d110a27a 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py @@ -11,6 +11,7 @@ gen_ai_attributes as GenAI, ) +from ..attributes import GEN_AI_WORKFLOW_NAME from ..instruments import Instruments from ..interfaces import EmitterMeta from ..types import ( @@ -101,6 +102,13 @@ def on_end(self, obj: Any) -> None: ) if llm_invocation.agent_id: metric_attrs[GenAI.GEN_AI_AGENT_ID] = llm_invocation.agent_id + workflow_name = ( + llm_invocation.attributes.get(GEN_AI_WORKFLOW_NAME) + if llm_invocation.attributes + else None + ) + if workflow_name: + metric_attrs[GEN_AI_WORKFLOW_NAME] = workflow_name _record_token_metrics( self._token_histogram, @@ -132,6 +140,13 @@ def on_end(self, obj: Any) -> None: ) if tool_invocation.agent_id: metric_attrs[GenAI.GEN_AI_AGENT_ID] = tool_invocation.agent_id + workflow_name = ( + tool_invocation.attributes.get(GEN_AI_WORKFLOW_NAME) + if tool_invocation.attributes + else None + ) + if workflow_name: + metric_attrs[GEN_AI_WORKFLOW_NAME] = workflow_name _record_duration( self._duration_histogram, @@ -163,6 +178,13 @@ def on_end(self, obj: Any) -> None: metric_attrs[GenAI.GEN_AI_AGENT_ID] = ( embedding_invocation.agent_id ) + workflow_name = ( + embedding_invocation.attributes.get(GEN_AI_WORKFLOW_NAME) + if embedding_invocation.attributes + else None + ) + if workflow_name: + metric_attrs[GEN_AI_WORKFLOW_NAME] = workflow_name _record_duration( self._duration_histogram, @@ -203,6 +225,13 @@ def on_error(self, error: Error, obj: Any) -> None: ) if llm_invocation.agent_id: metric_attrs[GenAI.GEN_AI_AGENT_ID] = llm_invocation.agent_id + workflow_name = ( + llm_invocation.attributes.get(GEN_AI_WORKFLOW_NAME) + if llm_invocation.attributes + else None + ) + if workflow_name: + metric_attrs[GEN_AI_WORKFLOW_NAME] = workflow_name if getattr(error, "type", None) is not None: metric_attrs[ErrorAttributes.ERROR_TYPE] = ( error.type.__qualname__ @@ -228,6 +257,13 @@ def on_error(self, error: Error, obj: Any) -> None: ) if tool_invocation.agent_id: metric_attrs[GenAI.GEN_AI_AGENT_ID] = tool_invocation.agent_id + workflow_name = ( + tool_invocation.attributes.get(GEN_AI_WORKFLOW_NAME) + if tool_invocation.attributes + else None + ) + if workflow_name: + metric_attrs[GEN_AI_WORKFLOW_NAME] = workflow_name if getattr(error, "type", None) is not None: metric_attrs[ErrorAttributes.ERROR_TYPE] = ( error.type.__qualname__ @@ -260,6 +296,13 @@ def on_error(self, error: Error, obj: Any) -> None: metric_attrs[GenAI.GEN_AI_AGENT_ID] = ( embedding_invocation.agent_id ) + workflow_name = ( + embedding_invocation.attributes.get(GEN_AI_WORKFLOW_NAME) + if embedding_invocation.attributes + else None + ) + if workflow_name: + metric_attrs[GEN_AI_WORKFLOW_NAME] = workflow_name if getattr(error, "type", None) is not None: metric_attrs[ErrorAttributes.ERROR_TYPE] = ( error.type.__qualname__ @@ -362,6 +405,13 @@ def _record_retrieval_metrics( metric_attrs[GenAI.GEN_AI_AGENT_NAME] = retrieval.agent_name if retrieval.agent_id: metric_attrs[GenAI.GEN_AI_AGENT_ID] = retrieval.agent_id + workflow_name = ( + retrieval.attributes.get(GEN_AI_WORKFLOW_NAME) + if retrieval.attributes + else None + ) + if workflow_name: + metric_attrs[GEN_AI_WORKFLOW_NAME] = workflow_name # Add error type if present if error is not None and getattr(error, "type", None) is not None: metric_attrs[ErrorAttributes.ERROR_TYPE] = error.type.__qualname__ diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py index 6b669cff..679e72e7 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py @@ -77,6 +77,11 @@ def genai_debug_log(*_args: Any, **_kwargs: Any) -> None: # type: ignore from opentelemetry.util.genai.emitters.configuration import ( build_emitter_pipeline, ) +from opentelemetry.util.genai.attributes import ( + GEN_AI_AGENT_ID, + GEN_AI_AGENT_NAME, + GEN_AI_WORKFLOW_NAME, +) from opentelemetry.util.genai.span_context import ( extract_span_context, span_context_hex_ids, @@ -254,6 +259,8 @@ def _get_eval_histogram(canonical_name: str): self._evaluation_manager = None # Active agent identity stack (name, id) for implicit propagation to nested operations self._agent_context_stack: list[tuple[str, str]] = [] + # Active workflow name stack for implicit propagation to nested operations + self._workflow_context_stack: list[str] = [] # Span registry (run_id -> Span) to allow parenting even after original invocation ended. # We intentionally retain ended parent spans to preserve trace linkage for late children # (e.g., final LLM call after agent/workflow termination). A lightweight size cap can be @@ -340,6 +347,51 @@ def _refresh_capture_content( except Exception: pass + @staticmethod + def _get_current_span_attribute(key: str) -> Optional[Any]: + """Best-effort extraction of an attribute from the active span.""" + try: + current_span = _trace_mod.get_current_span() + except Exception: + return None + if current_span is None: + return None + attributes = getattr(current_span, "attributes", None) + if attributes is None: + attributes = getattr(current_span, "_attributes", None) + if not attributes: + return None + try: + return attributes.get(key) + except Exception: + return None + + def _inherit_parent_context_attributes(self, invocation: GenAI) -> None: + """Propagate agent/workflow identity from active parent span context.""" + if not invocation.agent_name: + parent_agent_name = self._get_current_span_attribute( + GEN_AI_AGENT_NAME + ) + if isinstance(parent_agent_name, str) and parent_agent_name: + invocation.agent_name = parent_agent_name + + if not invocation.agent_id: + parent_agent_id = self._get_current_span_attribute(GEN_AI_AGENT_ID) + if isinstance(parent_agent_id, str) and parent_agent_id: + invocation.agent_id = parent_agent_id + + if GEN_AI_WORKFLOW_NAME not in invocation.attributes: + parent_workflow_name = self._get_current_span_attribute( + GEN_AI_WORKFLOW_NAME + ) + if ( + isinstance(parent_workflow_name, str) + and parent_workflow_name + ): + invocation.attributes[GEN_AI_WORKFLOW_NAME] = ( + parent_workflow_name + ) + def start_llm( self, invocation: LLMInvocation, @@ -357,6 +409,14 @@ def start_llm( invocation.agent_name = top_name if not invocation.agent_id: invocation.agent_id = top_id + if ( + GEN_AI_WORKFLOW_NAME not in invocation.attributes + and self._workflow_context_stack + ): + invocation.attributes[GEN_AI_WORKFLOW_NAME] = ( + self._workflow_context_stack[-1] + ) + self._inherit_parent_context_attributes(invocation) # Start invocation span; tracer context propagation handles parent/child links self._emitter.on_start(invocation) # Register span if created @@ -475,6 +535,14 @@ def start_embedding( invocation.agent_name = top_name if not invocation.agent_id: invocation.agent_id = top_id + if ( + GEN_AI_WORKFLOW_NAME not in invocation.attributes + and self._workflow_context_stack + ): + invocation.attributes[GEN_AI_WORKFLOW_NAME] = ( + self._workflow_context_stack[-1] + ) + self._inherit_parent_context_attributes(invocation) invocation.start_time = time.time() self._emitter.on_start(invocation) span = getattr(invocation, "span", None) @@ -541,6 +609,14 @@ def start_retrieval( invocation.agent_name = top_name if not invocation.agent_id: invocation.agent_id = top_id + if ( + GEN_AI_WORKFLOW_NAME not in invocation.attributes + and self._workflow_context_stack + ): + invocation.attributes[GEN_AI_WORKFLOW_NAME] = ( + self._workflow_context_stack[-1] + ) + self._inherit_parent_context_attributes(invocation) invocation.start_time = time.time() self._emitter.on_start(invocation) span = getattr(invocation, "span", None) @@ -603,6 +679,14 @@ def start_tool_call(self, invocation: ToolCall) -> ToolCall: invocation.agent_name = top_name if not invocation.agent_id: invocation.agent_id = top_id + if ( + GEN_AI_WORKFLOW_NAME not in invocation.attributes + and self._workflow_context_stack + ): + invocation.attributes[GEN_AI_WORKFLOW_NAME] = ( + self._workflow_context_stack[-1] + ) + self._inherit_parent_context_attributes(invocation) self._emitter.on_start(invocation) span = getattr(invocation, "span", None) if span is not None: @@ -643,6 +727,8 @@ def start_workflow(self, workflow: Workflow) -> Workflow: if span is not None: self._span_registry[str(workflow.run_id)] = span self._entity_registry[str(workflow.run_id)] = workflow + if workflow.name: + self._workflow_context_stack.append(workflow.name) return workflow def _handle_evaluation_results( @@ -784,6 +870,14 @@ def stop_workflow(self, workflow: Workflow) -> Workflow: self._meter_provider.force_flush() # type: ignore[attr-defined] except Exception: pass + try: + if ( + self._workflow_context_stack + and self._workflow_context_stack[-1] == workflow.name + ): + self._workflow_context_stack.pop() + except Exception: + pass return workflow def fail_workflow(self, workflow: Workflow, error: Error) -> Workflow: @@ -800,6 +894,14 @@ def fail_workflow(self, workflow: Workflow, error: Error) -> Workflow: self._meter_provider.force_flush() # type: ignore[attr-defined] except Exception: pass + try: + if ( + self._workflow_context_stack + and self._workflow_context_stack[-1] == workflow.name + ): + self._workflow_context_stack.pop() + except Exception: + pass return workflow # Agent lifecycle ----------------------------------------------------- @@ -808,6 +910,14 @@ def start_agent( ) -> AgentCreation | AgentInvocation: """Start an agent operation (create or invoke) and create a pending span entry.""" self._refresh_capture_content() + if ( + GEN_AI_WORKFLOW_NAME not in agent.attributes + and self._workflow_context_stack + ): + agent.attributes[GEN_AI_WORKFLOW_NAME] = ( + self._workflow_context_stack[-1] + ) + self._inherit_parent_context_attributes(agent) self._emitter.on_start(agent) span = getattr(agent, "span", None) if span is not None: diff --git a/util/opentelemetry-util-genai/tests/test_metrics.py b/util/opentelemetry-util-genai/tests/test_metrics.py index 686f0e8b..2cd45709 100644 --- a/util/opentelemetry-util-genai/tests/test_metrics.py +++ b/util/opentelemetry-util-genai/tests/test_metrics.py @@ -23,7 +23,10 @@ OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT_MODE, OTEL_INSTRUMENTATION_GENAI_EMITTERS, ) -from opentelemetry.util.genai.handler import get_telemetry_handler +from opentelemetry.util.genai.handler import ( + TelemetryHandler, + get_telemetry_handler, +) from opentelemetry.util.genai.types import ( AgentInvocation, Error, @@ -31,6 +34,7 @@ LLMInvocation, OutputMessage, Text, + Workflow, ) STABILITY_EXPERIMENTAL: dict[str, str] = {} @@ -403,6 +407,116 @@ def test_llm_metrics_inherit_agent_identity_from_context(self): "Expected metrics to inherit agent identity from active agent context", ) + def test_cross_handler_parent_context_propagates_agent_and_workflow(self): + env = { + **STABILITY_EXPERIMENTAL, + OTEL_INSTRUMENTATION_GENAI_EMITTERS: "span_metric", + OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT: "true", + OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT_MODE: "SPAN_ONLY", + } + with patch.dict(os.environ, env, clear=False): + if hasattr(get_telemetry_handler, "_default_handler"): + delattr(get_telemetry_handler, "_default_handler") + parent_handler = TelemetryHandler( + tracer_provider=self.tracer_provider, + meter_provider=self.meter_provider, + ) + child_handler = get_telemetry_handler( + tracer_provider=self.tracer_provider, + meter_provider=self.meter_provider, + ) + + workflow = Workflow( + name="crew_workflow", + workflow_type="crewai.crew", + framework="crewai", + system="crewai", + ) + parent_handler.start_workflow(workflow) + agent = AgentInvocation( + name="crew_agent", + model="agent-model", + framework="crewai", + system="crewai", + ) + parent_handler.start_agent(agent) + + inv = LLMInvocation( + request_model="m3", + input_messages=[ + InputMessage(role="user", parts=[Text(content="hello")]) + ], + ) + child_handler.start_llm(inv) + time.sleep(0.01) + inv.output_messages = [ + OutputMessage( + role="assistant", + parts=[Text(content="hi")], + finish_reason="stop", + ) + ] + inv.input_tokens = 3 + inv.output_tokens = 4 + child_handler.stop_llm(inv) + + parent_handler.stop_agent(agent) + parent_handler.stop_workflow(workflow) + + try: + self.meter_provider.force_flush() + except Exception: + pass + self.metric_reader.collect() + + self.assertEqual(inv.agent_name, "crew_agent") + self.assertEqual(inv.agent_id, str(agent.run_id)) + self.assertEqual( + inv.attributes.get("gen_ai.workflow.name"), "crew_workflow" + ) + + spans = self.span_exporter.get_finished_spans() + llm_span = next((s for s in spans if s.name == "chat m3"), None) + self.assertIsNotNone(llm_span, "Expected child LLM span to be emitted") + attrs = llm_span.attributes if llm_span is not None else {} + self.assertEqual(attrs.get("gen_ai.agent.name"), "crew_agent") + self.assertEqual(attrs.get("gen_ai.agent.id"), str(agent.run_id)) + self.assertEqual(attrs.get("gen_ai.workflow.name"), "crew_workflow") + + metrics_list = self._collect_metrics() + saw_duration = False + saw_tokens = False + for metric in metrics_list: + if metric.name not in ( + "gen_ai.client.token.usage", + "gen_ai.client.operation.duration", + ): + continue + data = getattr(metric, "data", None) + if not data: + continue + for dp in getattr(data, "data_points", []) or []: + metric_attrs = getattr(dp, "attributes", {}) or {} + if ( + metric_attrs.get("gen_ai.agent.name") == "crew_agent" + and metric_attrs.get("gen_ai.agent.id") + == str(agent.run_id) + and metric_attrs.get("gen_ai.workflow.name") + == "crew_workflow" + ): + if metric.name == "gen_ai.client.token.usage": + saw_tokens = True + else: + saw_duration = True + self.assertTrue( + saw_duration, + "Expected duration metric to include inherited agent/workflow context", + ) + self.assertTrue( + saw_tokens, + "Expected token usage metric to include inherited agent/workflow context", + ) + def test_llm_duration_metric_includes_error_type_on_failure(self): self._invoke_failure("span_metric") metrics_list = self._collect_metrics()