diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cc9ae6..38c8d2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ ADDED - Added `AsyncTaskHubGrpcClient` for asyncio-based applications using `grpc.aio` - Added `DefaultAsyncClientInterceptorImpl` for async gRPC metadata interceptors - Added `get_async_grpc_channel` helper for creating async gRPC channels +- Improved distributed tracing support with full span coverage for orchestrations, activities, sub-orchestrations, timers, and events CHANGED diff --git a/durabletask/client.py b/durabletask/client.py index 3a7c90f..aa8ab55 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import logging +import uuid from dataclasses import dataclass from datetime import datetime from enum import Enum @@ -16,6 +17,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared +import durabletask.internal.tracing as tracing from durabletask import task from durabletask.internal.client_helpers import ( build_query_entities_req, @@ -176,14 +178,28 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu tags: Optional[dict[str, str]] = None, version: Optional[str] = None) -> str: - req = build_schedule_new_orchestration_req( - orchestrator, input=input, instance_id=instance_id, start_at=start_at, - reuse_id_policy=reuse_id_policy, tags=tags, - version=version if version else self.default_version) - - self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") - res: pb.CreateInstanceResponse = self._stub.StartInstance(req) - return res.instanceId + name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) + resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex + resolved_version = version if version else self.default_version + + with 
tracing.start_create_orchestration_span( + name, resolved_instance_id, version=resolved_version, + ): + req = build_schedule_new_orchestration_req( + orchestrator, input=input, instance_id=instance_id, start_at=start_at, + reuse_id_policy=reuse_id_policy, tags=tags, + version=version if version else self.default_version) + + # Inject the active PRODUCER span context into the request so the sidecar + # stores it in the executionStarted event and the worker can parent all + # orchestration/activity/timer spans under this trace. + parent_trace_ctx = tracing.get_current_trace_context() + if parent_trace_ctx is not None: + req.parentTraceContext.CopyFrom(parent_trace_ctx) + + self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") + res: pb.CreateInstanceResponse = self._stub.StartInstance(req) + return res.instanceId def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) @@ -245,10 +261,10 @@ def wait_for_orchestration_completion(self, instance_id: str, *, def raise_orchestration_event(self, instance_id: str, event_name: str, *, data: Optional[Any] = None) -> None: - req = build_raise_event_req(instance_id, event_name, data) - - self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") - self._stub.RaiseEvent(req) + with tracing.start_raise_event_span(event_name, instance_id): + req = build_raise_event_req(instance_id, event_name, data) + self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + self._stub.RaiseEvent(req) def terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, @@ -418,14 +434,25 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator tags: Optional[dict[str, str]] = None, version: Optional[str] = None) -> str: - req = build_schedule_new_orchestration_req( - orchestrator, 
input=input, instance_id=instance_id, start_at=start_at, - reuse_id_policy=reuse_id_policy, tags=tags, - version=version if version else self.default_version) + name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) + resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex + resolved_version = version if version else self.default_version - self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") - res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) - return res.instanceId + with tracing.start_create_orchestration_span( + name, resolved_instance_id, version=resolved_version, + ): + req = build_schedule_new_orchestration_req( + orchestrator, input=input, instance_id=instance_id, start_at=start_at, + reuse_id_policy=reuse_id_policy, tags=tags, + version=version if version else self.default_version) + + parent_trace_ctx = tracing.get_current_trace_context() + if parent_trace_ctx is not None: + req.parentTraceContext.CopyFrom(parent_trace_ctx) + + self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") + res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) + return res.instanceId async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: @@ -487,10 +514,10 @@ async def wait_for_orchestration_completion(self, instance_id: str, *, async def raise_orchestration_event(self, instance_id: str, event_name: str, *, data: Optional[Any] = None) -> None: - req = build_raise_event_req(instance_id, event_name, data) - - self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") - await self._stub.RaiseEvent(req) + with tracing.start_raise_event_span(event_name, instance_id): + req = build_raise_event_req(instance_id, event_name, data) + self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + await self._stub.RaiseEvent(req) async def 
terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 4720046..03da314 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -27,7 +27,8 @@ def new_orchestrator_completed_event() -> pb.HistoryEvent: def new_execution_started_event(name: str, instance_id: str, encoded_input: Optional[str] = None, tags: Optional[dict[str, str]] = None, - version: Optional[str] = None) -> pb.HistoryEvent: + version: Optional[str] = None, + parent_trace_context: Optional[pb.TraceContext] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -36,7 +37,8 @@ def new_execution_started_event(name: str, instance_id: str, encoded_input: Opti version=get_string_value(version), input=get_string_value(encoded_input), orchestrationInstance=pb.OrchestrationInstance(instanceId=instance_id), - tags=tags)) + tags=tags, + parentTraceContext=parent_trace_context)) def new_timer_created_event(timer_id: int, fire_at: datetime) -> pb.HistoryEvent: @@ -223,11 +225,13 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], - tags: Optional[dict[str, str]]) -> pb.OrchestratorAction: + tags: Optional[dict[str, str]], + parent_trace_context: Optional[pb.TraceContext] = None) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction( name=name, input=get_string_value(encoded_input), - tags=tags + tags=tags, + parentTraceContext=parent_trace_context, )) @@ -302,12 +306,14 @@ def new_create_sub_orchestration_action( name: str, instance_id: Optional[str], encoded_input: Optional[str], - version: Optional[str]) -> pb.OrchestratorAction: + version: Optional[str], + parent_trace_context: Optional[pb.TraceContext] = None) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, 
createSubOrchestration=pb.CreateSubOrchestrationAction( name=name, instanceId=instance_id, input=get_string_value(encoded_input), - version=get_string_value(version) + version=get_string_value(version), + parentTraceContext=parent_trace_context, )) diff --git a/durabletask/internal/tracing.py b/durabletask/internal/tracing.py new file mode 100644 index 0000000..b3b6092 --- /dev/null +++ b/durabletask/internal/tracing.py @@ -0,0 +1,863 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""OpenTelemetry distributed tracing utilities for the Durable Task SDK. + +This module provides helpers for propagating W3C Trace Context between +orchestrations, activities, sub-orchestrations, and entities via the +``TraceContext`` protobuf message carried over gRPC. + +OpenTelemetry is an **optional** dependency. When the ``opentelemetry-api`` +package is not installed every helper gracefully degrades to a no-op so +that the rest of the SDK continues to work without any tracing overhead. 
+""" + +from __future__ import annotations + +import logging +import random +import time +from contextlib import contextmanager +from datetime import datetime +from typing import Any, Optional + +from google.protobuf import timestamp_pb2, wrappers_pb2 + +import durabletask.internal.orchestrator_service_pb2 as pb + +logger = logging.getLogger("durabletask-tracing") + +# --------------------------------------------------------------------------- +# Lazy / optional OpenTelemetry imports +# --------------------------------------------------------------------------- +try: + from opentelemetry import context as otel_context + from opentelemetry import trace + from opentelemetry.trace import ( + SpanKind, # type: ignore[no-redef] + StatusCode, # type: ignore[no-redef] + ) + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + _OTEL_AVAILABLE = True +except ImportError: # pragma: no cover + _OTEL_AVAILABLE = False + # Provide stub for SpanKind so callers can reference tracing.SpanKind + # without guarding every reference with OTEL_AVAILABLE checks. + + class SpanKind: # type: ignore[no-redef] + INTERNAL: Any = None + CLIENT: Any = None + SERVER: Any = None + PRODUCER: Any = None + CONSUMER: Any = None + + class StatusCode: # type: ignore[no-redef] + OK: Any = None + ERROR: Any = None + UNSET: Any = None + +# Re-export so callers can check without importing opentelemetry themselves. +OTEL_AVAILABLE = _OTEL_AVAILABLE + +# The instrumentation scope name used when creating spans. 
+_TRACER_NAME = "durabletask" + + +# --------------------------------------------------------------------------- +# Span attribute keys (mirrors Schema.cs from .NET SDK) +# --------------------------------------------------------------------------- + +ATTR_TASK_TYPE = "durabletask.type" +ATTR_TASK_NAME = "durabletask.task.name" +ATTR_TASK_VERSION = "durabletask.task.version" +ATTR_TASK_INSTANCE_ID = "durabletask.task.instance_id" +ATTR_TASK_STATUS = "durabletask.task.status" +ATTR_TASK_TASK_ID = "durabletask.task.task_id" +ATTR_EVENT_TARGET_INSTANCE_ID = "durabletask.event.target_instance_id" +ATTR_FIRE_AT = "durabletask.fire_at" + +# Task type values (used in span names and as attribute values) +TASK_TYPE_ORCHESTRATION = "orchestration" +TASK_TYPE_TIMER = "timer" +TASK_TYPE_EVENT = "event" + +# Span name type prefixes (composite types) +SPAN_TYPE_CREATE_ORCHESTRATION = "create_orchestration" +SPAN_TYPE_ORCHESTRATION_EVENT = "orchestration_event" + +# Task status values +TASK_STATUS_COMPLETED = "Completed" +TASK_STATUS_FAILED = "Failed" + +# W3C Trace Context carrier keys +CARRIER_KEY_TRACEPARENT = "traceparent" +CARRIER_KEY_TRACESTATE = "tracestate" + + +# --------------------------------------------------------------------------- +# Span name helpers (mirrors TraceActivityConstants / TraceHelper naming) +# --------------------------------------------------------------------------- + +def create_span_name( + span_type: str, task_name: str, version: Optional[str] = None, +) -> str: + """Build a span name with optional version suffix. 
+ + Examples:: + + create_span_name("orchestration", "MyOrch") -> "orchestration:MyOrch" + create_span_name("activity", "Say", "1.0") -> "activity:Say@(1.0)" + """ + if version: + return f"{span_type}:{task_name}@({version})" + return f"{span_type}:{task_name}" + + +def create_timer_span_name(orchestration_name: str) -> str: + """Build a timer span name: ``orchestration::timer``.""" + return f"{TASK_TYPE_ORCHESTRATION}:{orchestration_name}:{TASK_TYPE_TIMER}" + + +# --------------------------------------------------------------------------- +# Public helpers – extracting / injecting trace context +# --------------------------------------------------------------------------- + + +def _trace_context_from_carrier(carrier: dict[str, str]) -> Optional[pb.TraceContext]: + """Build a ``TraceContext`` protobuf from a W3C propagation carrier. + + Returns ``None`` when the carrier does not contain a valid + ``traceparent`` header. + """ + traceparent = carrier.get(CARRIER_KEY_TRACEPARENT) + if not traceparent: + return None + + tracestate = carrier.get(CARRIER_KEY_TRACESTATE) + # Format: 00--- + parts = traceparent.split("-") + span_id = parts[2] if len(parts) >= 4 else "" + + return pb.TraceContext( + traceParent=traceparent, + spanID=span_id, + traceState=wrappers_pb2.StringValue(value=tracestate) + if tracestate else None, + ) + + +def _parse_traceparent(traceparent: str) -> Optional[tuple[int, int, int]]: + """Parse a W3C traceparent string into ``(trace_id, span_id, trace_flags)``. + + Returns ``None`` when the string is not a valid traceparent. 
+ """ + parts = traceparent.split("-") + if len(parts) < 4: + return None + try: + trace_id = int(parts[1], 16) + span_id = int(parts[2], 16) + flags = int(parts[3], 16) + if trace_id == 0 or span_id == 0: + return None + return trace_id, span_id, flags + except ValueError: + return None + + +def get_current_trace_context() -> Optional[pb.TraceContext]: + """Capture the current OpenTelemetry span context as a protobuf ``TraceContext``. + + Returns ``None`` when OpenTelemetry is not installed or there is no + active span. + """ + if not _OTEL_AVAILABLE: + return None + + propagator = TraceContextTextMapPropagator() + carrier: dict[str, str] = {} + propagator.inject(carrier) + return _trace_context_from_carrier(carrier) + + +def extract_trace_context(proto_ctx: Optional[pb.TraceContext]) -> Optional[Any]: + """Convert a protobuf ``TraceContext`` into an OpenTelemetry ``Context``. + + Returns ``None`` when OpenTelemetry is not installed or the supplied + context is empty / ``None``. + """ + if not _OTEL_AVAILABLE or proto_ctx is None: + return None + + traceparent = proto_ctx.traceParent + if not traceparent: + return None + + carrier: dict[str, str] = {CARRIER_KEY_TRACEPARENT: traceparent} + if proto_ctx.HasField("traceState") and proto_ctx.traceState.value: + carrier[CARRIER_KEY_TRACESTATE] = proto_ctx.traceState.value + + propagator = TraceContextTextMapPropagator() + ctx = propagator.extract(carrier) + return ctx + + +@contextmanager +def start_span( + name: str, + trace_context: Optional[pb.TraceContext] = None, + kind: Any = None, + attributes: Optional[dict[str, str]] = None, +): + """Context manager that starts an OpenTelemetry span linked to a parent trace context. + + If OpenTelemetry is not installed, the block executes without tracing. + + Parameters + ---------- + name: + Human-readable span name (e.g. ``"activity:say_hello"``). + trace_context: + The protobuf ``TraceContext`` received from the sidecar. 
When + provided the new span will be created as a **child** of this + context. + kind: + The ``SpanKind`` for the new span. Defaults to ``SpanKind.INTERNAL``. + attributes: + Optional dictionary of span attributes. + """ + if not _OTEL_AVAILABLE: + yield None + return + + parent_ctx = extract_trace_context(trace_context) + + if kind is None: + kind = SpanKind.INTERNAL + + tracer = trace.get_tracer(_TRACER_NAME) + + if parent_ctx is not None: + token = otel_context.attach(parent_ctx) + try: + with tracer.start_as_current_span( + name, kind=kind, attributes=attributes + ) as span: + yield span + finally: + otel_context.detach(token) + else: + with tracer.start_as_current_span( + name, kind=kind, attributes=attributes + ) as span: + yield span + + +def set_span_error(span: Any, ex: Exception) -> None: + """Record an exception on the given span (if tracing is available).""" + if not _OTEL_AVAILABLE or span is None: + return + span.set_status(StatusCode.ERROR, str(ex)) + span.record_exception(ex) + + +# --------------------------------------------------------------------------- +# Orchestration-level span helpers +# --------------------------------------------------------------------------- + +def emit_orchestration_span( + name: str, + instance_id: str, + start_time_ns: Optional[int], + is_failed: bool, + failure_details: Any = None, + parent_trace_context: Optional[pb.TraceContext] = None, + orchestration_trace_context: Optional[pb.TraceContext] = None, + version: Optional[str] = None, +) -> None: + """Emit a SERVER span for a completed orchestration (create-and-end). + + The span is created with *start_time_ns* as its start time and + ended immediately. This avoids storing spans across dispatches. + + When *orchestration_trace_context* is provided, the span is emitted + with the pre-determined span ID from that context using the deferred + ``ReadableSpan`` approach. 
This ensures child spans (activities, + timers, sub-orchestrations) that were created with this context as + their parent are correctly nested under the orchestration span. + + Falls back to ``tracer.start_span()`` (which generates its own + span ID) when *orchestration_trace_context* is ``None``. + """ + if not _OTEL_AVAILABLE: + return + + span_name = create_span_name(TASK_TYPE_ORCHESTRATION, name, version) + + attrs: dict[str, str] = { + ATTR_TASK_TYPE: TASK_TYPE_ORCHESTRATION, + ATTR_TASK_NAME: name, + ATTR_TASK_INSTANCE_ID: instance_id, + } + if version: + attrs[ATTR_TASK_VERSION] = version + + # When we have a pre-determined orchestration span ID, use the + # deferred ReadableSpan approach so child spans match. + if orchestration_trace_context is not None: + if _emit_orchestration_span_deferred( + span_name, attrs, start_time_ns, is_failed, + failure_details, parent_trace_context, + orchestration_trace_context, + ): + return + + # Fallback: use the normal tracer.start_span() approach. 
+ tracer = trace.get_tracer(_TRACER_NAME) + parent_ctx = extract_trace_context(parent_trace_context) + + token = None + if parent_ctx is not None: + token = otel_context.attach(parent_ctx) + + try: + span = tracer.start_span( + span_name, + kind=SpanKind.SERVER, + attributes=attrs, + start_time=start_time_ns, + ) + + if is_failed: + msg = "" + if failure_details is not None: + msg = ( + str(failure_details.errorMessage) + if hasattr(failure_details, 'errorMessage') + else str(failure_details) + ) + span.set_status(StatusCode.ERROR, msg) + span.set_attribute(ATTR_TASK_STATUS, TASK_STATUS_FAILED) + else: + span.set_attribute(ATTR_TASK_STATUS, TASK_STATUS_COMPLETED) + + span.end() + finally: + if token is not None: + otel_context.detach(token) + + +def _emit_orchestration_span_deferred( + span_name: str, + attrs: dict[str, str], + start_time_ns: Optional[int], + is_failed: bool, + failure_details: Any, + parent_trace_context: Optional[pb.TraceContext], + orchestration_trace_context: pb.TraceContext, +) -> bool: + """Emit an orchestration SERVER span with a pre-determined span ID. + + Uses the same ``ReadableSpan`` approach as :func:`emit_client_span` + to reconstruct the span with the span ID from + *orchestration_trace_context*. The span is parented under the + PRODUCER span identified by *parent_trace_context*. + + Returns ``True`` when the deferred span was emitted successfully, + ``False`` when it could not be emitted (caller should fall back). 
+ """ + try: + from opentelemetry.sdk.trace import ReadableSpan as SdkReadableSpan + from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider + from opentelemetry.sdk.util.instrumentation import InstrumentationScope + from opentelemetry.trace.status import Status + except (ImportError, AttributeError): + return False + + orch_parsed = _parse_traceparent(orchestration_trace_context.traceParent) + if orch_parsed is None: + return False + trace_id_val, span_id_val, flags_val = orch_parsed + + # Parent is the PRODUCER span + parent_span_id_val = None + if parent_trace_context is not None: + parent_parsed = _parse_traceparent(parent_trace_context.traceParent) + if parent_parsed is not None: + parent_span_id_val = parent_parsed[1] + + span_context = trace.SpanContext( + trace_id=trace_id_val, + span_id=span_id_val, + is_remote=False, + trace_flags=trace.TraceFlags(flags_val), + ) + + parent_context = None + if parent_span_id_val is not None: + parent_context = trace.SpanContext( + trace_id=trace_id_val, + span_id=parent_span_id_val, + is_remote=True, + trace_flags=trace.TraceFlags(flags_val), + ) + + provider = trace.get_tracer_provider() + if not isinstance(provider, SdkTracerProvider): + return False + + if start_time_ns is None: + start_time_ns = time.time_ns() + end_time_ns = time.time_ns() + + try: + if is_failed: + msg = "" + if failure_details is not None: + msg = ( + str(failure_details.errorMessage) + if hasattr(failure_details, 'errorMessage') + else str(failure_details) + ) + status = Status(StatusCode.ERROR, msg) + attrs[ATTR_TASK_STATUS] = TASK_STATUS_FAILED + else: + status = Status(StatusCode.UNSET) + attrs[ATTR_TASK_STATUS] = TASK_STATUS_COMPLETED + + readable_span = SdkReadableSpan( + name=span_name, + context=span_context, + parent=parent_context, + kind=SpanKind.SERVER, + attributes=attrs, + start_time=start_time_ns, + end_time=end_time_ns, + resource=provider.resource, + instrumentation_scope=InstrumentationScope(_TRACER_NAME), + 
status=status, + ) + + processor = getattr(provider, '_active_span_processor', None) + if processor is not None: + processor.on_end(readable_span) + return True + return False + except Exception: + logger.debug( + "Failed to emit deferred orchestration SERVER span (OTel SDK " + "internals may have changed). Falling back to normal span.", + exc_info=True, + ) + return False + + +# --------------------------------------------------------------------------- +# CLIENT span helpers (deferred create / emit) +# +# Deferred CLIENT spans use opentelemetry-sdk internals (ReadableSpan +# constructor, TracerProvider._active_span_processor) to emit spans +# with pre-determined span IDs. _is_deferred_span_capable() validates +# these internals are accessible; if not, the SDK falls back to +# parenting SERVER spans directly under the orchestration span. +# --------------------------------------------------------------------------- + + +def _is_deferred_span_capable() -> bool: + """Check whether the OTel SDK internals needed for deferred CLIENT spans are available. + + Returns ``False`` when: + * ``opentelemetry-sdk`` is not installed (only the API is present), + * the active ``TracerProvider`` is not the SDK implementation, or + * the private ``_active_span_processor`` attribute is missing. + + Callers should skip CLIENT span generation when this returns ``False`` + so that the orchestration's own trace context is used as the parent + for downstream SERVER spans instead. 
+ """ + try: + from opentelemetry.sdk.trace import ReadableSpan # noqa: F401 + from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider + except (ImportError, AttributeError): + return False + + provider = trace.get_tracer_provider() + if not isinstance(provider, SdkTracerProvider): + return False + if not hasattr(provider, '_active_span_processor'): + return False + + return True + + +def generate_client_trace_context( + parent_trace_context: Optional[pb.TraceContext] = None, +) -> Optional[pb.TraceContext]: + """Generate a trace context for a deferred CLIENT span. + + Creates a new span ID and builds a W3C traceparent string **without** + creating an OpenTelemetry ``Span``. The actual CLIENT span will be + reconstructed later with proper timestamps via :func:`emit_client_span`. + + Returns ``None`` when OpenTelemetry is not available, the SDK + internals required for deferred span emission are not accessible, + or *parent_trace_context* is empty / invalid. When ``None`` is + returned the caller should fall back to the orchestration's own + trace context as the parent for downstream spans. + """ + if not _OTEL_AVAILABLE: + return None + if parent_trace_context is None: + return None + + # Pre-flight: ensure the SDK internals we need in emit_client_span() + # are accessible. If not, return None so the caller falls back to + # the orchestration trace context — the SERVER span will be parented + # directly under the orchestration span instead of a CLIENT span. 
+ if not _is_deferred_span_capable(): + return None + + parsed = _parse_traceparent(parent_trace_context.traceParent) + if parsed is None: + return None + + trace_id, _parent_span_id, flags = parsed + + # Generate a new span ID for the CLIENT span + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + + traceparent = f"00-{trace_id:032x}-{span_id:016x}-{flags:02x}" + return pb.TraceContext( + traceParent=traceparent, + spanID=format(span_id, '016x'), + ) + + +def emit_client_span( + task_type: str, + name: str, + instance_id: str, + task_id: int, + client_trace_context: pb.TraceContext, + parent_trace_context: Optional[pb.TraceContext] = None, + start_time_ns: Optional[int] = None, + end_time_ns: Optional[int] = None, + is_error: bool = False, + error_message: Optional[str] = None, + version: Optional[str] = None, +) -> None: + """Emit a CLIENT span with a specific span ID reconstructed from history. + + The span ID is extracted from *client_trace_context* so that it matches + the one previously propagated to the downstream SERVER span. A + ``ReadableSpan`` is constructed and fed directly to the active span + processor for export. + + *start_time_ns* and *end_time_ns* should come from the history event + timestamps (``taskScheduled`` / ``taskCompleted``, etc.). + + If the SDK internals are unavailable or have changed, this function + logs a debug message and returns silently. The CLIENT span is + simply omitted from the trace; the activity/sub-orchestration + SERVER span remains connected to the orchestration span via the + fallback parenting established in :func:`generate_client_trace_context`. + """ + if not _OTEL_AVAILABLE: + return + + # SDK-internal imports — see _is_deferred_span_capable() for details. 
+ try: + from opentelemetry.sdk.trace import ReadableSpan as SdkReadableSpan + from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider + from opentelemetry.sdk.util.instrumentation import InstrumentationScope + from opentelemetry.trace.status import Status + except (ImportError, AttributeError): + return + + client_parsed = _parse_traceparent(client_trace_context.traceParent) + if client_parsed is None: + return + trace_id_val, span_id_val, flags_val = client_parsed + + # Determine parent span ID from the orchestration's trace context + parent_span_id_val = None + if parent_trace_context is not None: + parent_parsed = _parse_traceparent(parent_trace_context.traceParent) + if parent_parsed is not None: + parent_span_id_val = parent_parsed[1] + + span_context = trace.SpanContext( + trace_id=trace_id_val, + span_id=span_id_val, + is_remote=False, + trace_flags=trace.TraceFlags(flags_val), + ) + + parent_context = None + if parent_span_id_val is not None: + parent_context = trace.SpanContext( + trace_id=trace_id_val, + span_id=parent_span_id_val, + is_remote=True, + trace_flags=trace.TraceFlags(flags_val), + ) + + span_name = create_span_name(task_type, name, version) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: task_type, + ATTR_TASK_NAME: name, + ATTR_TASK_INSTANCE_ID: instance_id, + ATTR_TASK_TASK_ID: str(task_id), + } + if version: + attrs[ATTR_TASK_VERSION] = version + + provider = trace.get_tracer_provider() + if not isinstance(provider, SdkTracerProvider): + return + + if start_time_ns is None or end_time_ns is None: + now_ns = time.time_ns() + if start_time_ns is None: + start_time_ns = now_ns + if end_time_ns is None: + end_time_ns = now_ns + + # Construct a ReadableSpan with the pre-determined span ID and feed + # it to the span processor for export. 
+ try: + status = Status( + StatusCode.ERROR, error_message or "" + ) if is_error else Status(StatusCode.UNSET) + + readable_span = SdkReadableSpan( + name=span_name, + context=span_context, + parent=parent_context, + kind=SpanKind.CLIENT, + attributes=attrs, + start_time=start_time_ns, + end_time=end_time_ns, + resource=provider.resource, + instrumentation_scope=InstrumentationScope(_TRACER_NAME), + status=status, + ) + + processor = getattr(provider, '_active_span_processor', None) + if processor is not None: + processor.on_end(readable_span) + except Exception: + logger.debug( + "Failed to emit deferred CLIENT span (OTel SDK internals may " + "have changed). The trace will be missing the CLIENT span " + "layer but remains connected via the orchestration span.", + exc_info=True, + ) + + +def emit_timer_span( + orchestration_name: str, + instance_id: str, + timer_id: int, + fire_at: datetime, + scheduled_time_ns: Optional[int] = None, + parent_trace_context: Optional[pb.TraceContext] = None, +) -> None: + """Emit an Internal span for a timer (emit-and-close pattern). + + When *scheduled_time_ns* is provided the span start time is backdated + to when the timer was originally created, so the span duration covers + the full wait period. + + When *parent_trace_context* is provided the span is created as a + child of that context; otherwise it inherits the ambient context. 
+ """ + if not _OTEL_AVAILABLE: + return + + span_name = create_timer_span_name(orchestration_name) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: TASK_TYPE_TIMER, + ATTR_TASK_NAME: orchestration_name, + ATTR_TASK_INSTANCE_ID: instance_id, + ATTR_TASK_TASK_ID: str(timer_id), + ATTR_FIRE_AT: fire_at.isoformat(), + } + + tracer = trace.get_tracer(_TRACER_NAME) + parent_ctx = extract_trace_context(parent_trace_context) + + token = None + if parent_ctx is not None: + token = otel_context.attach(parent_ctx) + + try: + span = tracer.start_span( + span_name, + kind=SpanKind.INTERNAL, + attributes=attrs, + start_time=scheduled_time_ns, + ) + span.end() + finally: + if token is not None: + otel_context.detach(token) + + +def emit_event_raised_span( + event_name: str, + instance_id: str, + target_instance_id: Optional[str] = None, + parent_trace_context: Optional[pb.TraceContext] = None, +) -> None: + """Emit a Producer span for an event raised from the orchestration. + + When *parent_trace_context* is provided the span is created as a + child of that context; otherwise it inherits the ambient context. 
+ """ + if not _OTEL_AVAILABLE: + return + + span_name = create_span_name(SPAN_TYPE_ORCHESTRATION_EVENT, event_name) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: TASK_TYPE_EVENT, + ATTR_TASK_NAME: event_name, + ATTR_TASK_INSTANCE_ID: instance_id, + } + if target_instance_id: + attrs[ATTR_EVENT_TARGET_INSTANCE_ID] = target_instance_id + + tracer = trace.get_tracer(_TRACER_NAME) + parent_ctx = extract_trace_context(parent_trace_context) + + token = None + if parent_ctx is not None: + token = otel_context.attach(parent_ctx) + + try: + span = tracer.start_span( + span_name, + kind=SpanKind.PRODUCER, + attributes=attrs, + ) + span.end() + finally: + if token is not None: + otel_context.detach(token) + + +# --------------------------------------------------------------------------- +# Client-side Producer span helpers +# --------------------------------------------------------------------------- + +@contextmanager +def start_create_orchestration_span( + name: str, + instance_id: str, + version: Optional[str] = None, +): + """Context manager for a Producer span when scheduling a new orchestration. + + Yields the span; caller should capture the trace context after entering + the span context so it can be injected into the gRPC request. 
+ """ + if not _OTEL_AVAILABLE: + yield None + return + + span_name = create_span_name(SPAN_TYPE_CREATE_ORCHESTRATION, name, version) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: TASK_TYPE_ORCHESTRATION, + ATTR_TASK_NAME: name, + ATTR_TASK_INSTANCE_ID: instance_id, + } + if version: + attrs[ATTR_TASK_VERSION] = version + + tracer = trace.get_tracer(_TRACER_NAME) + with tracer.start_as_current_span( + span_name, + kind=SpanKind.PRODUCER, + attributes=attrs, + ) as span: + yield span + + +@contextmanager +def start_raise_event_span( + event_name: str, + target_instance_id: str, +): + """Context manager for a Producer span when raising an event from the client.""" + if not _OTEL_AVAILABLE: + yield None + return + + span_name = create_span_name(SPAN_TYPE_ORCHESTRATION_EVENT, event_name) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: TASK_TYPE_EVENT, + ATTR_TASK_NAME: event_name, + ATTR_EVENT_TARGET_INSTANCE_ID: target_instance_id, + } + + tracer = trace.get_tracer(_TRACER_NAME) + with tracer.start_as_current_span( + span_name, + kind=SpanKind.PRODUCER, + attributes=attrs, + ) as span: + yield span + + +def reconstruct_trace_context( + parent_trace_context: pb.TraceContext, + span_id: str, +) -> Optional[pb.TraceContext]: + """Reconstruct a ``TraceContext`` with a specific span ID. + + Uses the trace ID and flags from *parent_trace_context* but replaces + the span ID with *span_id*. This is used to reuse a pre-determined + orchestration span ID across replays. 
+ """ + if not _OTEL_AVAILABLE: + return None + + parsed = _parse_traceparent(parent_trace_context.traceParent) + if parsed is None: + return None + trace_id, _, flags = parsed + traceparent = f"00-{trace_id:032x}-{span_id}-{flags:02x}" + return pb.TraceContext( + traceParent=traceparent, + spanID=span_id, + ) + + +def build_orchestration_trace_context( + start_time_ns: Optional[int], + span_id: Optional[str] = None, +) -> Optional[pb.OrchestrationTraceContext]: + """Build an ``OrchestrationTraceContext`` protobuf to return to the sidecar. + + This preserves both the orchestration start time and span ID across + replays so that all dispatches produce a consistent orchestration + SERVER span. + """ + if start_time_ns is None: + return None + + ctx = pb.OrchestrationTraceContext() + + ts = timestamp_pb2.Timestamp() + ts.FromNanoseconds(start_time_ns) + ctx.spanStartTime.CopyFrom(ts) + + if span_id: + ctx.spanID.CopyFrom(wrappers_pb2.StringValue(value=span_id)) + + return ctx diff --git a/durabletask/task.py b/durabletask/task.py index 0ef03da..2375f12 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -6,7 +6,7 @@ import math from abc import ABC, abstractmethod -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext @@ -463,7 +463,7 @@ def compute_next_delay(self) -> Optional[timedelta]: else: backoff_coefficient = self._retry_policy.backoff_coefficient - if datetime.utcnow() < retry_expiration: + if datetime.now(tz=timezone.utc).replace(tzinfo=None) < retry_expiration: next_delay_f = math.pow(backoff_coefficient, self._attempt_count - 1) * self._retry_policy.first_retry_interval.total_seconds() if self._retry_policy.max_retry_interval is not None: diff --git a/durabletask/worker.py b/durabletask/worker.py index 442165d..9c7f2d4 100644 --- 
a/durabletask/worker.py +++ b/durabletask/worker.py @@ -7,6 +7,7 @@ import logging import os import random +import time from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone from threading import Event, Thread @@ -32,6 +33,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared +import durabletask.internal.tracing as tracing from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl @@ -633,18 +635,82 @@ def _execute_orchestrator( stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): + instance_id = req.instanceId + + # Extract parent trace context from executionStarted event + parent_trace_ctx = None + orchestration_name = "" + for e in list(req.pastEvents) + list(req.newEvents): + if e.HasField("executionStarted"): + orchestration_name = e.executionStarted.name + if e.executionStarted.HasField("parentTraceContext"): + parent_trace_ctx = e.executionStarted.parentTraceContext + break + + # Determine the orchestration start time: reuse persisted value + # from a prior dispatch, or capture a new one. 
+ if (req.HasField("orchestrationTraceContext") and req.orchestrationTraceContext.HasField("spanStartTime")): + start_time_ns = req.orchestrationTraceContext.spanStartTime.ToNanoseconds() + else: + start_time_ns = time.time_ns() + + # Extract persisted orchestration span ID from a prior dispatch + persisted_orch_span_id = None + if (req.HasField("orchestrationTraceContext") and req.orchestrationTraceContext.HasField("spanID") and req.orchestrationTraceContext.spanID.value): + persisted_orch_span_id = req.orchestrationTraceContext.spanID.value + try: - executor = _OrchestrationExecutor(self._registry, self._logger) - result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) + executor = _OrchestrationExecutor( + self._registry, self._logger, + persisted_orch_span_id=persisted_orch_span_id) + result = executor.execute(instance_id, req.pastEvents, req.newEvents) + + # Determine completion status for span + is_complete = False + is_failed = False + failure_details = None + for action in result.actions: + if action.HasField("completeOrchestration"): + is_complete = True + orch_status = action.completeOrchestration.orchestrationStatus + if orch_status == pb.ORCHESTRATION_STATUS_FAILED: + is_failed = True + failure_details = action.completeOrchestration.failureDetails + + if is_complete: + # Orchestration finished — emit a single span covering its lifetime + tracing.emit_orchestration_span( + orchestration_name, + instance_id, + start_time_ns, + is_failed, + failure_details=failure_details, + parent_trace_context=parent_trace_ctx, + orchestration_trace_context=result._orchestration_trace_context, + ) + + # Include the span ID in the orchestration trace context + # so it persists across dispatches. 
+ orch_span_id = None + if result._orchestration_trace_context: + orch_span_id = result._orchestration_trace_context.spanID + orch_trace_ctx = tracing.build_orchestration_trace_context( + start_time_ns, span_id=orch_span_id) + res = pb.OrchestratorResponse( - instanceId=req.instanceId, + instanceId=instance_id, actions=result.actions, customStatus=ph.get_string_value(result.encoded_custom_status), completionToken=completionToken, + orchestrationTraceContext=( + orch_trace_ctx if orch_trace_ctx + else req.orchestrationTraceContext + ), ) except pe.AbandonOrchestrationError: + # Abandoned — no span needed self._logger.info( - f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'" + f"Abandoning orchestration. InstanceId = '{instance_id}'. Completion token = '{completionToken}'" ) stub.AbandonTaskOrchestratorWorkItem( pb.AbandonOrchestrationTaskRequest( @@ -653,8 +719,17 @@ def _execute_orchestrator( ) return except Exception as ex: + # Unhandled error — emit a failed span + tracing.emit_orchestration_span( + orchestration_name, + instance_id, + start_time_ns, + is_failed=True, + failure_details=ex, + parent_trace_context=parent_trace_ctx, + ) self._logger.exception( - f"An error occurred while trying to execute instance '{req.instanceId}': {ex}" + f"An error occurred while trying to execute instance '{instance_id}': {ex}" ) failure_details = ph.new_failure_details(ex) actions = [ @@ -663,7 +738,7 @@ def _execute_orchestrator( ) ] res = pb.OrchestratorResponse( - instanceId=req.instanceId, + instanceId=instance_id, actions=actions, completionToken=completionToken, ) @@ -697,9 +772,24 @@ def _execute_activity( instance_id = req.orchestrationInstance.instanceId try: executor = _ActivityExecutor(self._registry, self._logger) - result = executor.execute( - instance_id, req.name, req.taskId, req.input.value - ) + with tracing.start_span( + tracing.create_span_name("activity", req.name), + trace_context=req.parentTraceContext, + 
kind=tracing.SpanKind.SERVER, + attributes={ + tracing.ATTR_TASK_TYPE: "activity", + tracing.ATTR_TASK_INSTANCE_ID: instance_id, + tracing.ATTR_TASK_NAME: req.name, + tracing.ATTR_TASK_TASK_ID: str(req.taskId), + }, + ) as span: + try: + result = executor.execute( + instance_id, req.name, req.taskId, req.input.value + ) + except Exception as ex: + tracing.set_span_error(span, ex) + raise res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, @@ -759,30 +849,45 @@ def _execute_entity_batch( operation_result = None - try: - entity_result = executor.execute( - instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value - ) - - entity_result = ph.get_string_value_or_empty(entity_result) - operation_result = pb.OperationResult(success=pb.OperationResultSuccess( - result=entity_result, - startTimeUtc=new_timestamp(start_time), - endTimeUtc=new_timestamp(datetime.now(timezone.utc)) - )) - results.append(operation_result) - - entity_state.commit() - except Exception as ex: - self._logger.exception(ex) - operation_result = pb.OperationResult(failure=pb.OperationResultFailure( - failureDetails=ph.new_failure_details(ex), - startTimeUtc=new_timestamp(start_time), - endTimeUtc=new_timestamp(datetime.now(timezone.utc)) - )) - results.append(operation_result) + # Get the trace context for this operation, if available + op_trace_ctx = operation.traceContext if operation.HasField("traceContext") else None + + with tracing.start_span( + tracing.create_span_name("entity", f"{entity_instance_id.entity}:{operation.operation}"), + trace_context=op_trace_ctx, + kind=tracing.SpanKind.SERVER, + attributes={ + tracing.ATTR_TASK_TYPE: "entity", + tracing.ATTR_TASK_INSTANCE_ID: instance_id, + tracing.ATTR_TASK_NAME: entity_instance_id.entity, + "durabletask.entity.operation": operation.operation, + }, + ) as span: + try: + entity_result = executor.execute( + instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value 
+ ) - entity_state.rollback() + entity_result = ph.get_string_value_or_empty(entity_result) + operation_result = pb.OperationResult(success=pb.OperationResultSuccess( + result=entity_result, + startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.commit() + except Exception as ex: + tracing.set_span_error(span, ex) + self._logger.exception(ex) + operation_result = pb.OperationResult(failure=pb.OperationResultFailure( + failureDetails=ph.new_failure_details(ex), + startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.rollback() batch_result = pb.EntityBatchResult( results=results, @@ -847,6 +952,8 @@ def __init__(self, instance_id: str, registry: _Registry): self._new_input: Optional[Any] = None self._save_events = False self._encoded_custom_status: Optional[str] = None + self._parent_trace_context: Optional[pb.TraceContext] = None + self._orchestration_trace_context: Optional[pb.TraceContext] = None def run(self, generator: Generator[task.Task, Any, Any]): self._generator = generator @@ -1136,15 +1243,38 @@ def call_activity_function_helper( if isinstance(activity_function, str) else task.get_name(activity_function) ) - action = ph.new_schedule_task_action(id, name, encoded_input, tags) + # Generate a trace context for the deferred CLIENT span. + # The actual span is emitted later with proper timestamps + # when the taskCompleted/taskFailed event arrives. 
+ orch_ctx = self._orchestration_trace_context or self._parent_trace_context + parent_ctx = orch_ctx + if not self._is_replaying: + client_ctx = tracing.generate_client_trace_context( + parent_trace_context=orch_ctx) + if client_ctx is not None: + parent_ctx = client_ctx + action = ph.new_schedule_task_action( + id, name, encoded_input, tags, + parent_trace_context=parent_ctx) else: if instance_id is None: # Create a deteministic instance ID based on the parent instance ID instance_id = f"{self.instance_id}:{id:04x}" if not isinstance(activity_function, str): raise ValueError("Orchestrator function name must be a string") + # Generate a trace context for the deferred CLIENT span. + # The actual span is emitted later with proper timestamps + # when the sub-orchestration completes or fails. + orch_ctx = self._orchestration_trace_context or self._parent_trace_context + parent_ctx = orch_ctx + if not self._is_replaying: + client_ctx = tracing.generate_client_trace_context( + parent_trace_context=orch_ctx) + if client_ctx is not None: + parent_ctx = client_ctx action = ph.new_create_sub_orchestration_action( - id, activity_function, instance_id, encoded_input, version + id, activity_function, instance_id, encoded_input, version, + parent_trace_context=parent_ctx ) self._pending_actions[id] = action @@ -1277,22 +1407,39 @@ def new_uuid(self) -> str: class ExecutionResults: actions: list[pb.OrchestratorAction] encoded_custom_status: Optional[str] + _orchestration_trace_context: Optional[pb.TraceContext] def __init__( - self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str] + self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str], + orchestration_trace_context: Optional[pb.TraceContext] = None, ): self.actions = actions self.encoded_custom_status = encoded_custom_status + self._orchestration_trace_context = orchestration_trace_context class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None - def 
__init__(self, registry: _Registry, logger: logging.Logger): + def __init__( + self, + registry: _Registry, + logger: logging.Logger, + persisted_orch_span_id: Optional[str] = None, + ): self._registry = registry self._logger = logger self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] + self._persisted_orch_span_id = persisted_orch_span_id + # Maps timer_id -> (fire_at, created_time_ns) + self._timer_fire_at: dict[int, tuple[datetime, Optional[int]]] = {} + # Maps task_id -> (task_type, name, instance_id, scheduled_ns, + # client_trace_ctx, version) + # Used to reconstruct CLIENT spans with proper timestamps. + self._task_scheduled_info: dict[ + int, tuple[str, str, str, Optional[int], pb.TraceContext, Optional[str]] + ] = {} def execute( self, @@ -1304,6 +1451,7 @@ def execute( orchestration_started_events = [e for e in old_events if e.HasField("executionStarted")] if len(orchestration_started_events) >= 1: orchestration_name = orchestration_started_events[0].executionStarted.name + self._orchestration_name = orchestration_name self._logger.debug( f"{instance_id}: Beginning replay for orchestrator {orchestration_name}..." 
@@ -1371,7 +1519,8 @@ def execute( f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}" ) return ExecutionResults( - actions=actions, encoded_custom_status=ctx._encoded_custom_status + actions=actions, encoded_custom_status=ctx._encoded_custom_status, + orchestration_trace_context=ctx._orchestration_trace_context, ) def process_event( @@ -1382,12 +1531,10 @@ def process_event( self._suspended_events.append(event) return - # CONSIDER: change to a switch statement with event.WhichOneof("eventType") try: if event.HasField("orchestratorStarted"): ctx.current_utc_datetime = event.timestamp.ToDatetime() elif event.HasField("executionStarted"): - # TODO: Check if we already started the orchestration fn = self._registry.get_orchestrator(event.executionStarted.name) if fn is None: raise OrchestratorNotRegisteredError( @@ -1397,6 +1544,21 @@ def process_event( if event.executionStarted.version: ctx._version = event.executionStarted.version.value + # Store the parent trace context for propagation to child tasks + if event.executionStarted.HasField("parentTraceContext"): + ctx._parent_trace_context = event.executionStarted.parentTraceContext + # Reuse a persisted span ID from a prior dispatch so + # activities/timers/sub-orchestrations across all + # dispatches share the same parent. On the first + # dispatch, generate a new random span ID. 
+ if self._persisted_orch_span_id: + ctx._orchestration_trace_context = tracing.reconstruct_trace_context( + ctx._parent_trace_context, + self._persisted_orch_span_id) + else: + ctx._orchestration_trace_context = tracing.generate_client_trace_context( + parent_trace_context=ctx._parent_trace_context) + if self._registry.versioning: version_failure = self.evaluate_orchestration_versioning( self._registry.versioning, @@ -1440,16 +1602,34 @@ def process_event( raise _get_wrong_action_type_error( timer_id, expected_method_name, action ) + # Track timer fire_at and creation timestamp for span emission + if action.createTimer.HasField("fireAt"): + created_ns = (event.timestamp.ToNanoseconds() + if event.HasField("timestamp") else None) + self._timer_fire_at[timer_id] = ( + action.createTimer.fireAt.ToDatetime(), created_ns, + ) elif event.HasField("timerFired"): timer_id = event.timerFired.timerId timer_task = ctx._pending_tasks.pop(timer_id, None) if not timer_task: - # TODO: Should this be an error? When would it ever happen? - if not ctx._is_replaying: + # Unexpected event for unknown timer; log and skip. + if not ctx.is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}." 
) return + # Emit timer span with backdated start time (skip during replay) + if not ctx.is_replaying: + timer_info = self._timer_fire_at.get(timer_id) + if timer_info is not None: + fire_at, created_ns = timer_info + tracing.emit_timer_span( + self._orchestration_name, ctx.instance_id, + timer_id, fire_at, + scheduled_time_ns=created_ns, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + ) timer_task.complete(None) if timer_task._retryable_parent is not None: activity_action = timer_task._retryable_parent._action @@ -1493,17 +1673,39 @@ def process_event( expected_task_name=event.taskScheduled.name, actual_task_name=action.scheduleTask.name, ) + # Store info for deferred CLIENT span reconstruction + ts_evt = event.taskScheduled + if ts_evt.HasField("parentTraceContext") and ts_evt.parentTraceContext.traceParent: + sched_ns = event.timestamp.ToNanoseconds() if event.HasField("timestamp") else None + ver_str = ts_evt.version.value if ts_evt.HasField("version") else None + self._task_scheduled_info[task_id] = ( + "activity", ts_evt.name, ctx.instance_id, + sched_ns, ts_evt.parentTraceContext, ver_str, + ) elif event.HasField("taskCompleted"): # This history event contains the result of a completed activity task. task_id = event.taskCompleted.taskScheduledId activity_task = ctx._pending_tasks.pop(task_id, None) if not activity_task: - # TODO: Should this be an error? When would it ever happen? + # Unexpected completion for unknown task; log and skip. if not ctx.is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}." 
) return + # Emit deferred CLIENT span with proper timestamps + if not ctx.is_replaying: + info = self._task_scheduled_info.pop(task_id, None) + if info is not None: + t_type, t_name, t_iid, s_ns, c_ctx, t_ver = info + e_ns = event.timestamp.ToNanoseconds() if event.HasField("timestamp") else None + tracing.emit_client_span( + t_type, t_name, t_iid, task_id, + client_trace_context=c_ctx, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + start_time_ns=s_ns, end_time_ns=e_ns, + version=t_ver, + ) result = None if not ph.is_empty(event.taskCompleted.result): result = shared.from_json(event.taskCompleted.result.value) @@ -1513,13 +1715,29 @@ def process_event( task_id = event.taskFailed.taskScheduledId activity_task = ctx._pending_tasks.pop(task_id, None) if not activity_task: - # TODO: Should this be an error? When would it ever happen? + # Unexpected failure for unknown task; log and skip. if not ctx.is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}." 
) return + # Emit deferred CLIENT span with error status + if not ctx.is_replaying: + info = self._task_scheduled_info.pop(task_id, None) + if info is not None: + t_type, t_name, t_iid, s_ns, c_ctx, t_ver = info + e_ns = event.timestamp.ToNanoseconds() if event.HasField("timestamp") else None + tracing.emit_client_span( + t_type, t_name, t_iid, task_id, + client_trace_context=c_ctx, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + start_time_ns=s_ns, end_time_ns=e_ns, + is_error=True, + error_message=str(event.taskFailed.failureDetails.errorMessage), + version=t_ver, + ) + if isinstance(activity_task, task.RetryableTask): if activity_task._retry_policy is not None: next_delay = activity_task.compute_next_delay() @@ -1563,16 +1781,38 @@ def process_event( expected_task_name=event.subOrchestrationInstanceCreated.name, actual_task_name=action.createSubOrchestration.name, ) + # Store info for deferred CLIENT span reconstruction + sub_evt = event.subOrchestrationInstanceCreated + if sub_evt.HasField("parentTraceContext") and sub_evt.parentTraceContext.traceParent: + sched_ns = event.timestamp.ToNanoseconds() if event.HasField("timestamp") else None + ver_str = sub_evt.version.value if sub_evt.HasField("version") else None + self._task_scheduled_info[task_id] = ( + "orchestration", sub_evt.name, sub_evt.instanceId, + sched_ns, sub_evt.parentTraceContext, ver_str, + ) elif event.HasField("subOrchestrationInstanceCompleted"): task_id = event.subOrchestrationInstanceCompleted.taskScheduledId sub_orch_task = ctx._pending_tasks.pop(task_id, None) if not sub_orch_task: - # TODO: Should this be an error? When would it ever happen? + # Unexpected completion for unknown sub-orchestration; log and skip. if not ctx.is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}." 
) return + # Emit deferred CLIENT span with proper timestamps + if not ctx.is_replaying: + info = self._task_scheduled_info.pop(task_id, None) + if info is not None: + t_type, t_name, t_iid, s_ns, c_ctx, t_ver = info + e_ns = event.timestamp.ToNanoseconds() if event.HasField("timestamp") else None + tracing.emit_client_span( + t_type, t_name, t_iid, task_id, + client_trace_context=c_ctx, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + start_time_ns=s_ns, end_time_ns=e_ns, + version=t_ver, + ) result = None if not ph.is_empty(event.subOrchestrationInstanceCompleted.result): result = shared.from_json( @@ -1585,12 +1825,27 @@ def process_event( task_id = failedEvent.taskScheduledId sub_orch_task = ctx._pending_tasks.pop(task_id, None) if not sub_orch_task: - # TODO: Should this be an error? When would it ever happen? + # Unexpected failure for unknown sub-orchestration; log and skip. if not ctx.is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}." 
) return + # Emit deferred CLIENT span with error status + if not ctx.is_replaying: + info = self._task_scheduled_info.pop(task_id, None) + if info is not None: + t_type, t_name, t_iid, s_ns, c_ctx, t_ver = info + e_ns = event.timestamp.ToNanoseconds() if event.HasField("timestamp") else None + tracing.emit_client_span( + t_type, t_name, t_iid, task_id, + client_trace_context=c_ctx, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + start_time_ns=s_ns, end_time_ns=e_ns, + is_error=True, + error_message=str(failedEvent.failureDetails.errorMessage), + version=t_ver, + ) if isinstance(sub_orch_task, task.RetryableTask): if sub_orch_task._retry_policy is not None: next_delay = sub_orch_task.compute_next_delay() @@ -1733,7 +1988,7 @@ def process_event( section_id = event.entityLockGranted.criticalSectionId task_id = ctx._entity_lock_id_map.pop(section_id, None) if not task_id: - # TODO: Should this be an error? When would it ever happen? + # Unexpected lock grant for unknown section; log and skip. if not ctx.is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." diff --git a/examples/distributed-tracing/README.md b/examples/distributed-tracing/README.md new file mode 100644 index 0000000..d47a045 --- /dev/null +++ b/examples/distributed-tracing/README.md @@ -0,0 +1,186 @@ +# Distributed Tracing Example + +This example demonstrates how to set up **distributed tracing** with the +Durable Task Python SDK using [OpenTelemetry](https://opentelemetry.io/) +and [Jaeger](https://www.jaegertracing.io/) as the trace backend. + +The sample orchestration showcases three key Durable Task features that +all produce correlated trace spans: + +1. **Timers** — a short delay before starting work. +1. **Sub-orchestration** — delegates city-level weather collection to a + child orchestration. +1. 
**Activities** — individual activity calls to fetch weather data and
+  produce a summary.
+
+## Prerequisites
+
+- [Docker](https://www.docker.com/) (for the emulator and Jaeger)
+- Python 3.10+
+
+## Quick Start
+
+### 1. Start the DTS Emulator
+
+```bash
+docker run --name dtsemulator -d -p 8080:8080 mcr.microsoft.com/dts/dts-emulator:latest
+```
+
+### 2. Start Jaeger
+
+Jaeger's all-in-one image accepts OTLP over gRPC on port **4317** and
+serves the UI on port **16686**:
+
+```bash
+docker run --name jaeger -d \
+  -p 4317:4317 \
+  -p 16686:16686 \
+  jaegertracing/all-in-one:latest
+```
+
+PowerShell:
+
+```powershell
+docker run --name jaeger -d `
+  -p 4317:4317 `
+  -p 16686:16686 `
+  jaegertracing/all-in-one:latest
+```
+
+### 3. Install Dependencies
+
+Create and activate a virtual environment, then install the required
+packages:
+
+```bash
+python -m venv .venv
+```
+
+Bash:
+
+```bash
+source .venv/bin/activate
+```
+
+PowerShell:
+
+```powershell
+.\.venv\Scripts\Activate.ps1
+```
+
+Install requirements:
+
+```bash
+pip install -r requirements.txt
+```
+
+If you are running from a local clone of the repository, install the
+local packages in editable mode instead (run from the repo root):
+
+```bash
+pip install -e ".[opentelemetry]" -e ./durabletask-azuremanaged
+```
+
+### 4. Run the Example
+
+```bash
+python app.py
+```
+
+Once the orchestration completes, open the Jaeger UI at
+<http://localhost:16686>, select the **durabletask-tracing-example**
+service, and click **Find Traces** to explore the spans.
+
+## What You Will See in Jaeger
+
+A single trace for the orchestration will contain spans for:
+
+- **`orchestration:weather_report_orchestrator`** — the top-level
+  orchestration span.
+- **`timer`** — the brief timer delay before work starts.
+- **`orchestration:collect_weather`** — the sub-orchestration span.
+- **`activity:get_weather`** — one span per city
+  (Tokyo, Seattle, London).
+- **`activity:summarize`** — the final summarization activity.
+ +All spans share the same trace ID, so you can follow the full execution +flow from the parent orchestration through the sub-orchestration and +into each activity. + +## Configuration + +The example reads the following environment variables (all optional): + +| Variable | Default | Description | +|---|---|---| +| `ENDPOINT` | `http://localhost:8080` | DTS emulator / scheduler endpoint | +| `TASKHUB` | `default` | Task hub name | +| `OTEL_EXPORTER_OTLP_ENDPOINT` | `http://localhost:4317` | OTLP gRPC endpoint (Jaeger) | + +## Important Usage Guidelines for Distributed Tracing + +### Install the OpenTelemetry extras + +The SDK ships OpenTelemetry as an **optional** dependency. Install it +with the `opentelemetry` extra: + +```bash +pip install "durabletask[opentelemetry]" +``` + +Without these packages the SDK still works, but no trace spans are +emitted. + +### Configure the `TracerProvider` before starting the worker + +OpenTelemetry requires a configured `TracerProvider` with at least one +`SpanProcessor` and exporter **before** any spans are created. In +practice this means setting it up at the top of your entry-point module, +before constructing the worker or client: + +```python +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.resources import Resource +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + +resource = Resource.create({"service.name": "my-app"}) +provider = TracerProvider(resource=resource) +provider.add_span_processor( + BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4317", insecure=True)) +) +trace.set_tracer_provider(provider) +``` + +### Flush spans before exiting + +The `BatchSpanProcessor` buffers spans and exports them in the +background. If the process exits before the buffer is flushed, some +spans may be lost. 
Call `provider.force_flush()` (and optionally add a +short sleep) before your program terminates: + +```python +provider.force_flush() +``` + +### Orchestrator code must remain deterministic + +Distributed tracing does **not** change the determinism requirement for +orchestrator functions. Do not create your own OpenTelemetry spans +inside orchestrator code — the SDK handles span creation automatically. +Activity functions and client code are free to create additional spans +as needed. + +### Use `BatchSpanProcessor` in production + +`SimpleSpanProcessor` exports every span synchronously, which adds +latency to every operation. Use `BatchSpanProcessor` for production +workloads to avoid performance overhead. + +### Choose the right exporter for your backend + +This example uses the OTLP/gRPC exporter, which is compatible with +Jaeger 1.35+, the OpenTelemetry Collector, Azure Monitor (via the +Azure Monitor OpenTelemetry exporter), and many other backends. Swap +the exporter if your tracing backend uses a different protocol. diff --git a/examples/distributed-tracing/app.py b/examples/distributed-tracing/app.py new file mode 100644 index 0000000..8cb09b8 --- /dev/null +++ b/examples/distributed-tracing/app.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Distributed tracing example using OpenTelemetry and Jaeger. + +This example demonstrates how to configure OpenTelemetry distributed tracing +with the Durable Task Python SDK. The orchestration showcases timers, +activities, and a sub-orchestration, all producing correlated trace spans +visible in the Jaeger UI. 
+ +Prerequisites: + - DTS emulator running on localhost:8080 + - Jaeger running on localhost:4317 (OTLP gRPC) / localhost:16686 (UI) + - pip install -r requirements.txt +""" + +import os +import time +from datetime import timedelta + +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +# --------------------------------------------------------------------------- +# OpenTelemetry configuration — MUST be done before any spans are created +# --------------------------------------------------------------------------- + +OTEL_ENDPOINT = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + +resource = Resource.create({"service.name": "durabletask-tracing-example"}) +provider = TracerProvider(resource=resource) +provider.add_span_processor( + BatchSpanProcessor( + OTLPSpanExporter(endpoint=OTEL_ENDPOINT, insecure=True) + ) +) +trace.set_tracer_provider(provider) + + +# --------------------------------------------------------------------------- +# Activity functions +# --------------------------------------------------------------------------- + +def get_weather(ctx: task.ActivityContext, city: str) -> str: + """Simulate fetching weather data for a city.""" + # In a real app this would call an external API + weather_data = { + "Tokyo": "Sunny, 22°C", + "Seattle": "Rainy, 12°C", + "London": "Cloudy, 15°C", + } + result = weather_data.get(city, "Unknown") + print(f" [Activity] get_weather({city}) -> {result}") + return result + + +def summarize(ctx: task.ActivityContext, reports: list) -> str: + 
"""Combine individual weather reports into a summary string.""" + summary = " | ".join(reports) + print(f" [Activity] summarize -> {summary}") + return summary + + +# --------------------------------------------------------------------------- +# Sub-orchestration +# --------------------------------------------------------------------------- + +def collect_weather(ctx: task.OrchestrationContext, cities: list): + """Sub-orchestration that collects weather for a list of cities.""" + results = [] + for city in cities: + weather = yield ctx.call_activity(get_weather, input=city) + results.append(f"{city}: {weather}") + return results + + +# --------------------------------------------------------------------------- +# Main orchestration +# --------------------------------------------------------------------------- + +def weather_report_orchestrator(ctx: task.OrchestrationContext, cities: list): + """Top-level orchestration demonstrating timers, activities, and sub-orchestrations. + + Flow: + 1. Wait for a short timer (simulating a scheduled delay). + 2. Call a sub-orchestration to collect weather data for each city. + 3. Call an activity to summarize the results. 
+ """ + # Step 1 — Timer: wait briefly before starting work + yield ctx.create_timer(timedelta(milliseconds=100)) + if not ctx.is_replaying: + print(" [Orchestrator] Timer fired — starting weather collection") + + # Step 2 — Sub-orchestration: delegate city-level work + reports = yield ctx.call_sub_orchestrator(collect_weather, input=cities) + + # Step 3 — Activity: summarize the collected reports + summary = yield ctx.call_activity(summarize, input=reports) + + return summary + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + # Use environment variables if provided, otherwise use default emulator values + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + print(f"Using taskhub: {taskhub_name}") + print(f"Using endpoint: {endpoint}") + print(f"OTLP endpoint: {OTEL_ENDPOINT}") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + secure_channel = endpoint.startswith("https://") + credential = DefaultAzureCredential() if secure_channel else None + + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + ) as w: + # Register orchestrators and activities + w.add_orchestrator(weather_report_orchestrator) + w.add_orchestrator(collect_weather) + w.add_activity(get_weather) + w.add_activity(summarize) + w.start() + print("Worker started.") + + # Create client, schedule the orchestration, and wait for completion + c = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + ) + + cities = ["Tokyo", "Seattle", "London"] + instance_id = c.schedule_new_orchestration( + weather_report_orchestrator, input=cities, + ) + print(f"Orchestration started: 
{instance_id}") + + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f"Orchestration completed! Result: {state.serialized_output}") + elif state: + print(f"Orchestration failed: {state.failure_details}") + + # Flush any remaining spans to the exporter + provider.force_flush() + time.sleep(1) + + print("Done. Open Jaeger at http://localhost:16686 to view traces.") diff --git a/examples/distributed-tracing/images/dts-dashboard-completed.png b/examples/distributed-tracing/images/dts-dashboard-completed.png new file mode 100644 index 0000000..6ad28aa Binary files /dev/null and b/examples/distributed-tracing/images/dts-dashboard-completed.png differ diff --git a/examples/distributed-tracing/images/jaeger-full-trace-detail.png b/examples/distributed-tracing/images/jaeger-full-trace-detail.png new file mode 100644 index 0000000..f009184 Binary files /dev/null and b/examples/distributed-tracing/images/jaeger-full-trace-detail.png differ diff --git a/examples/distributed-tracing/images/jaeger-span-detail.png b/examples/distributed-tracing/images/jaeger-span-detail.png new file mode 100644 index 0000000..c6be960 Binary files /dev/null and b/examples/distributed-tracing/images/jaeger-span-detail.png differ diff --git a/examples/distributed-tracing/requirements.txt b/examples/distributed-tracing/requirements.txt new file mode 100644 index 0000000..12ac909 --- /dev/null +++ b/examples/distributed-tracing/requirements.txt @@ -0,0 +1,4 @@ +durabletask[opentelemetry] +durabletask-azuremanaged +azure-identity +opentelemetry-exporter-otlp-proto-grpc diff --git a/pyproject.toml b/pyproject.toml index b9a8287..04cfcc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,12 @@ dependencies = [ "packaging" ] +[project.optional-dependencies] +opentelemetry = [ + "opentelemetry-api>=1.0.0", + "opentelemetry-sdk>=1.0.0" +] + [project.urls] repository = 
"https://github.com/microsoft/durabletask-python" changelog = "https://github.com/microsoft/durabletask-python/blob/main/CHANGELOG.md" diff --git a/requirements.txt b/requirements.txt index 166d047..ee1cad9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,6 @@ pytest-asyncio pytest-cov azure-identity asyncio -packaging \ No newline at end of file +packaging +opentelemetry-api +opentelemetry-sdk \ No newline at end of file diff --git a/tests/durabletask/test_tracing.py b/tests/durabletask/test_tracing.py new file mode 100644 index 0000000..9969afa --- /dev/null +++ b/tests/durabletask/test_tracing.py @@ -0,0 +1,1907 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for distributed tracing utilities and integration.""" + +import json +import logging +from datetime import datetime, timezone +from unittest.mock import patch + +import pytest +from google.protobuf import timestamp_pb2, wrappers_pb2 + +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import StatusCode + +import durabletask.internal.helpers as helpers +import durabletask.internal.orchestrator_service_pb2 as pb +import durabletask.internal.tracing as tracing +from durabletask import task, worker + +logging.basicConfig( + format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.DEBUG) +TEST_LOGGER = logging.getLogger("tests") +TEST_INSTANCE_ID = 'abc123' + + +# Module-level setup: create a TracerProvider with an InMemorySpanExporter once. +# Newer OpenTelemetry versions only allow set_tracer_provider to be called once. 
+_EXPORTER = InMemorySpanExporter() +_PROVIDER = TracerProvider() +_PROVIDER.add_span_processor(SimpleSpanProcessor(_EXPORTER)) +trace.set_tracer_provider(_PROVIDER) + + +@pytest.fixture(autouse=True) +def otel_setup(): + """Clear the in-memory exporter before each test.""" + _EXPORTER.clear() + yield _EXPORTER + + +# --------------------------------------------------------------------------- +# Shared test constants and helpers +# --------------------------------------------------------------------------- + +_SAMPLE_TRACE_ID = "0af7651916cd43dd8448eb211c80319c" +_SAMPLE_PARENT_SPAN_ID = "b7ad6b7169203331" +_SAMPLE_CLIENT_SPAN_ID = "00f067aa0ba902b7" + + +def _make_parent_trace_ctx(): + return pb.TraceContext( + traceParent=f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_PARENT_SPAN_ID}-01", + spanID=_SAMPLE_PARENT_SPAN_ID, + ) + + +def _make_client_trace_ctx(): + return pb.TraceContext( + traceParent=f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_CLIENT_SPAN_ID}-01", + spanID=_SAMPLE_CLIENT_SPAN_ID, + ) + + +# --------------------------------------------------------------------------- +# Tests for tracing utility functions +# --------------------------------------------------------------------------- + + +class TestGetCurrentTraceContext: + """Tests for tracing.get_current_trace_context().""" + + def test_returns_none_when_no_active_span(self, otel_setup): + """When there is no active span, should return None.""" + result = tracing.get_current_trace_context() + assert result is None + + def test_returns_trace_context_with_active_span(self, otel_setup): + """When there is an active span, should return a populated TraceContext.""" + tracer = trace.get_tracer("test") + with tracer.start_as_current_span("test-span"): + result = tracing.get_current_trace_context() + + assert result is not None + assert isinstance(result, pb.TraceContext) + assert result.traceParent != "" + assert result.spanID != "" + # traceparent format: 00-<trace-id>-<span-id>-<trace-flags> + parts = result.traceParent.split("-") + assert len(parts) == 4 + 
assert parts[0] == "00" + assert len(parts[1]) == 32 # trace ID + assert len(parts[2]) == 16 # span ID + assert result.spanID == parts[2] + + +class TestExtractTraceContext: + """Tests for tracing.extract_trace_context().""" + + def test_returns_none_for_none_input(self): + result = tracing.extract_trace_context(None) + assert result is None + + def test_returns_none_for_empty_traceparent(self): + proto_ctx = pb.TraceContext(traceParent="", spanID="") + result = tracing.extract_trace_context(proto_ctx) + assert result is None + + def test_extracts_valid_context(self, otel_setup): + """Should extract a valid OTel context from a protobuf TraceContext.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + proto_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + otel_ctx = tracing.extract_trace_context(proto_ctx) + assert otel_ctx is not None + + def test_extracts_context_with_tracestate(self, otel_setup): + """Should extract context including tracestate.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + tracestate_val = "congo=t61rcWkgMzE" + proto_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + traceState=wrappers_pb2.StringValue(value=tracestate_val), + ) + otel_ctx = tracing.extract_trace_context(proto_ctx) + assert otel_ctx is not None + + +class TestStartSpan: + """Tests for tracing.start_span().""" + + def test_creates_span_without_parent(self, otel_setup: InMemorySpanExporter): + """Should create a span even without a parent trace context.""" + with tracing.start_span("test-span") as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "test-span" + + def test_creates_span_with_attributes(self, otel_setup: InMemorySpanExporter): + """Should create a span with custom attributes.""" + attrs = {"key1": "value1", "key2": "value2"} + with tracing.start_span("test-span", 
attributes=attrs) as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].attributes is not None + assert spans[0].attributes["key1"] == "value1" + assert spans[0].attributes["key2"] == "value2" + + def test_creates_child_span_from_trace_context(self, otel_setup: InMemorySpanExporter): + """Should create a child span linked to the parent trace context.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + proto_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + with tracing.start_span("child-span", trace_context=proto_ctx) as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + child_span = spans[0] + assert child_span.name == "child-span" + # The child span's trace ID should match the parent's + assert child_span.context is not None + assert child_span.context.trace_id == int("0af7651916cd43dd8448eb211c80319c", 16) + + +class TestSetSpanError: + """Tests for tracing.set_span_error().""" + + def test_records_error_on_span(self, otel_setup: InMemorySpanExporter): + """Should record error status and exception on the span.""" + with tracing.start_span("error-span") as span: + ex = ValueError("something went wrong") + tracing.set_span_error(span, ex) + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].status.description is not None + assert "something went wrong" in spans[0].status.description + + def test_noop_with_none_span(self): + """Should not raise when span is None.""" + tracing.set_span_error(None, ValueError("test")) + + +# --------------------------------------------------------------------------- +# Tests for orchestration trace context propagation +# --------------------------------------------------------------------------- + + +class TestOrchestrationTraceContextPropagation: + """Tests that 
orchestration actions include trace context.""" + + def test_schedule_task_action_includes_trace_context(self): + """new_schedule_task_action should include parentTraceContext when provided.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + action = helpers.new_schedule_task_action( + 0, "my_activity", None, None, + parent_trace_context=parent_ctx + ) + assert action.scheduleTask.parentTraceContext.traceParent == traceparent + + def test_schedule_task_action_without_trace_context(self): + """new_schedule_task_action should work without trace context.""" + action = helpers.new_schedule_task_action(0, "my_activity", None, None) + # parentTraceContext should not be set (default empty) + assert action.scheduleTask.parentTraceContext.traceParent == "" + + def test_create_sub_orchestration_action_includes_trace_context(self): + """new_create_sub_orchestration_action should include parentTraceContext.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + action = helpers.new_create_sub_orchestration_action( + 0, "sub_orch", "inst1", None, None, + parent_trace_context=parent_ctx + ) + assert action.createSubOrchestration.parentTraceContext.traceParent == traceparent + + def test_create_sub_orchestration_action_without_trace_context(self): + """new_create_sub_orchestration_action should work without trace context.""" + action = helpers.new_create_sub_orchestration_action( + 0, "sub_orch", "inst1", None, None + ) + assert action.createSubOrchestration.parentTraceContext.traceParent == "" + + +class TestOrchestrationExecutorStoresTraceContext: + """Tests that the orchestration executor extracts and stores trace context from events.""" + + def test_execution_started_stores_parent_trace_context(self): + """process_event should store 
parentTraceContext from executionStarted.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + + def simple_orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + registry = worker._Registry() + registry.add_orchestrator(simple_orchestrator) + + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + assert ctx._parent_trace_context is None + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + + # Create an executionStarted event with parentTraceContext + event = pb.HistoryEvent( + eventId=-1, + executionStarted=pb.ExecutionStartedEvent( + name="simple_orchestrator", + orchestrationInstance=pb.OrchestrationInstance(instanceId=TEST_INSTANCE_ID), + parentTraceContext=parent_ctx, + ) + ) + + executor.process_event(ctx, event) + assert ctx._parent_trace_context is not None + assert ctx._parent_trace_context.traceParent == traceparent + + def test_execution_started_without_trace_context(self): + """process_event should leave parentTraceContext as None when not provided.""" + def simple_orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + registry = worker._Registry() + registry.add_orchestrator(simple_orchestrator) + + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + + event = pb.HistoryEvent( + eventId=-1, + executionStarted=pb.ExecutionStartedEvent( + name="simple_orchestrator", + orchestrationInstance=pb.OrchestrationInstance(instanceId=TEST_INSTANCE_ID), + ) + ) + + executor.process_event(ctx, event) + assert ctx._parent_trace_context is None + assert ctx._orchestration_trace_context is None + + def test_execution_started_generates_orchestration_trace_context(self): + """process_event should generate _orchestration_trace_context + from the parent trace context for use as child span parent.""" + 
traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + + def simple_orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + registry = worker._Registry() + registry.add_orchestrator(simple_orchestrator) + + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + assert ctx._orchestration_trace_context is None + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + event = pb.HistoryEvent( + eventId=-1, + executionStarted=pb.ExecutionStartedEvent( + name="simple_orchestrator", + orchestrationInstance=pb.OrchestrationInstance(instanceId=TEST_INSTANCE_ID), + parentTraceContext=parent_ctx, + ) + ) + + executor.process_event(ctx, event) + orch_ctx = ctx._orchestration_trace_context + assert orch_ctx is not None + # The orchestration context should have the same trace ID + # but a different span ID (pre-determined for the SERVER span) + parts = orch_ctx.traceParent.split("-") + assert parts[1] == "0af7651916cd43dd8448eb211c80319c" + assert parts[2] != "b7ad6b7169203331" + assert orch_ctx.spanID == parts[2] + + def test_execution_started_reuses_persisted_span_id(self): + """When a persisted orchestration span ID exists from a prior dispatch, + process_event should reuse that span ID instead of generating a new one.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + persisted_span_id = "00f067aa0ba902b7" + + def simple_orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + registry = worker._Registry() + registry.add_orchestrator(simple_orchestrator) + + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + executor = worker._OrchestrationExecutor( + registry, TEST_LOGGER, + persisted_orch_span_id=persisted_span_id, + ) + + event = pb.HistoryEvent( + eventId=-1, + 
executionStarted=pb.ExecutionStartedEvent( + name="simple_orchestrator", + orchestrationInstance=pb.OrchestrationInstance(instanceId=TEST_INSTANCE_ID), + parentTraceContext=parent_ctx, + ) + ) + + executor.process_event(ctx, event) + orch_ctx = ctx._orchestration_trace_context + assert orch_ctx is not None + # Should reuse the persisted span ID + parts = orch_ctx.traceParent.split("-") + assert parts[1] == "0af7651916cd43dd8448eb211c80319c" # same trace ID + assert parts[2] == persisted_span_id # reused span ID + assert orch_ctx.spanID == persisted_span_id + + +class TestOtelNotAvailable: + """Tests that tracing functions gracefully degrade when OTel is unavailable.""" + + def test_get_current_trace_context_without_otel(self): + """get_current_trace_context returns None when OTel is not available.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + result = tracing.get_current_trace_context() + assert result is None + + def test_extract_trace_context_without_otel(self): + """extract_trace_context returns None when OTel is not available.""" + proto_ctx = pb.TraceContext(traceParent="00-abc-def-01", spanID="def") + with patch.object(tracing, '_OTEL_AVAILABLE', False): + result = tracing.extract_trace_context(proto_ctx) + assert result is None + + def test_start_span_without_otel(self): + """start_span should yield None when OTel is not available.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + with tracing.start_span("test") as span: + assert span is None + + def test_set_span_error_without_otel(self): + """set_span_error should be a no-op when OTel is not available.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.set_span_error(None, ValueError("test")) # should not raise + + def test_emit_orchestration_span_without_otel(self): + """emit_orchestration_span is a no-op when OTel is unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.emit_orchestration_span( + "test_orch", "inst1", 1000, False) + 
+ def test_emit_timer_span_without_otel(self): + """emit_timer_span is a no-op when OTel is unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.emit_timer_span("orch", "inst1", 1, datetime.now(timezone.utc)) + + def test_start_create_orchestration_span_without_otel(self): + """start_create_orchestration_span yields None when OTel unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + with tracing.start_create_orchestration_span("orch", "inst1") as span: + assert span is None + + def test_start_raise_event_span_without_otel(self): + """start_raise_event_span yields None when OTel unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + with tracing.start_raise_event_span("evt", "inst1") as span: + assert span is None + + def test_emit_event_raised_span_without_otel(self): + """emit_event_raised_span is a no-op when OTel is unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.emit_event_raised_span("evt", "inst1") + + +# --------------------------------------------------------------------------- +# Tests for span naming helpers +# --------------------------------------------------------------------------- + + +class TestSpanNaming: + """Tests for create_span_name and create_timer_span_name.""" + + def test_create_span_name_without_version(self): + assert tracing.create_span_name("orchestration", "MyOrch") == "orchestration:MyOrch" + + def test_create_span_name_with_version(self): + assert tracing.create_span_name("activity", "Say", "1.0") == "activity:Say@(1.0)" + + def test_create_timer_span_name(self): + assert tracing.create_timer_span_name("MyOrch") == "orchestration:MyOrch:timer" + + +# --------------------------------------------------------------------------- +# Tests for schema attribute constants +# --------------------------------------------------------------------------- + + +class TestSchemaConstants: + """Tests that schema constants match expected names.""" + 
+ def test_attribute_keys_defined(self): + assert tracing.ATTR_TASK_TYPE == "durabletask.type" + assert tracing.ATTR_TASK_NAME == "durabletask.task.name" + assert tracing.ATTR_TASK_VERSION == "durabletask.task.version" + assert tracing.ATTR_TASK_INSTANCE_ID == "durabletask.task.instance_id" + assert tracing.ATTR_TASK_STATUS == "durabletask.task.status" + assert tracing.ATTR_TASK_TASK_ID == "durabletask.task.task_id" + assert tracing.ATTR_EVENT_TARGET_INSTANCE_ID == "durabletask.event.target_instance_id" + assert tracing.ATTR_FIRE_AT == "durabletask.fire_at" + + +# --------------------------------------------------------------------------- +# Tests for Producer / Client / Server span creation +# --------------------------------------------------------------------------- + + +class TestCreateOrchestrationSpan: + """Tests for start_create_orchestration_span (Producer span).""" + + def test_creates_producer_span(self, otel_setup: InMemorySpanExporter): + """Should create a Producer span for create_orchestration.""" + with tracing.start_create_orchestration_span("MyOrch", "inst-123") as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "create_orchestration:MyOrch" + assert s.kind == trace.SpanKind.PRODUCER + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "orchestration" + assert s.attributes[tracing.ATTR_TASK_NAME] == "MyOrch" + assert s.attributes[tracing.ATTR_TASK_INSTANCE_ID] == "inst-123" + + def test_creates_producer_span_with_version(self, otel_setup: InMemorySpanExporter): + with tracing.start_create_orchestration_span("MyOrch", "inst-123", version="2.0"): + pass + + spans = otel_setup.get_finished_spans() + assert spans[0].name == "create_orchestration:MyOrch@(2.0)" + assert spans[0].attributes is not None + assert spans[0].attributes[tracing.ATTR_TASK_VERSION] == "2.0" + + def test_trace_context_injected_inside_producer_span(self, otel_setup: 
InMemorySpanExporter): + """Inside the producer span, get_current_trace_context should capture producer span ctx.""" + with tracing.start_create_orchestration_span("Orch", "inst"): + ctx = tracing.get_current_trace_context() + assert ctx is not None + assert ctx.traceParent != "" + + +class TestRaiseEventSpan: + """Tests for start_raise_event_span (Producer span).""" + + def test_creates_producer_span(self, otel_setup: InMemorySpanExporter): + with tracing.start_raise_event_span("MyEvent", "inst-456") as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "orchestration_event:MyEvent" + assert s.kind == trace.SpanKind.PRODUCER + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "event" + assert s.attributes[tracing.ATTR_TASK_NAME] == "MyEvent" + assert s.attributes[tracing.ATTR_EVENT_TARGET_INSTANCE_ID] == "inst-456" + + +class TestOrchestrationServerSpan: + """Tests for emit_orchestration_span.""" + + def test_emits_server_span(self, otel_setup: InMemorySpanExporter): + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + import time + start_ns = time.time_ns() + tracing.emit_orchestration_span( + "MyOrch", "inst-100", start_ns, False, + parent_trace_context=parent_ctx, + ) + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "orchestration:MyOrch" + assert s.kind == trace.SpanKind.SERVER + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "orchestration" + assert s.attributes[tracing.ATTR_TASK_NAME] == "MyOrch" + assert s.attributes[tracing.ATTR_TASK_INSTANCE_ID] == "inst-100" + assert s.attributes[tracing.ATTR_TASK_STATUS] == "Completed" + + def test_server_span_failure(self, otel_setup: InMemorySpanExporter): + import time + tracing.emit_orchestration_span( + 
"FailOrch", "inst-200", time.time_ns(), True, "boom", + ) + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].attributes is not None + assert spans[0].attributes[tracing.ATTR_TASK_STATUS] == "Failed" + + def test_server_span_backdated(self, otel_setup: InMemorySpanExporter): + """The span start time should honour the provided start_time_ns.""" + start_ns = 1704067200000000000 # 2024-01-01T00:00:00Z + tracing.emit_orchestration_span( + "BackdatedOrch", "inst-300", start_ns, False, + ) + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].start_time == start_ns + + def test_deferred_span_uses_predetermined_span_id(self, otel_setup: InMemorySpanExporter): + """When orchestration_trace_context is provided, the SERVER span + should use the pre-determined span ID from that context.""" + parent_tp = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=parent_tp, spanID="b7ad6b7169203331", + ) + orch_span_id = "00f067aa0ba902b7" + orch_tp = f"00-0af7651916cd43dd8448eb211c80319c-{orch_span_id}-01" + orch_ctx = pb.TraceContext( + traceParent=orch_tp, spanID=orch_span_id, + ) + import time + start_ns = time.time_ns() + tracing.emit_orchestration_span( + "DeferredOrch", "inst-400", start_ns, False, + parent_trace_context=parent_ctx, + orchestration_trace_context=orch_ctx, + ) + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "orchestration:DeferredOrch" + assert s.kind == trace.SpanKind.SERVER + # Span ID should match the pre-determined orchestration context + assert s.context is not None + assert s.context.span_id == int(orch_span_id, 16) + # Parent should be the PRODUCER span + assert s.parent is not None + assert s.parent.span_id == int("b7ad6b7169203331", 16) + + def test_deferred_span_with_failure(self, otel_setup: InMemorySpanExporter): + """When 
        orchestration_trace_context is provided and is_failed=True,
        the deferred span should have ERROR status."""
        parent_tp = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        parent_ctx = pb.TraceContext(
            traceParent=parent_tp, spanID="b7ad6b7169203331",
        )
        orch_span_id = "00f067aa0ba902b7"
        orch_tp = f"00-0af7651916cd43dd8448eb211c80319c-{orch_span_id}-01"
        orch_ctx = pb.TraceContext(
            traceParent=orch_tp, spanID=orch_span_id,
        )
        import time
        tracing.emit_orchestration_span(
            "FailOrch", "inst-500", time.time_ns(), True, "kaboom",
            parent_trace_context=parent_ctx,
            orchestration_trace_context=orch_ctx,
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        assert s.status.status_code == StatusCode.ERROR
        assert s.attributes is not None
        assert s.attributes[tracing.ATTR_TASK_STATUS] == "Failed"
        assert s.context is not None
        # The emitted span must reuse the exact span ID carried in the
        # orchestration trace context, not mint a new one.
        assert s.context.span_id == int(orch_span_id, 16)


# ---------------------------------------------------------------------------
# Tests for emit_timer_span
# ---------------------------------------------------------------------------


class TestTimerSpan:
    """Tests for emit_timer_span."""

    def test_emits_internal_span(self, otel_setup: InMemorySpanExporter):
        fire_at = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        tracing.emit_timer_span("MyOrch", "inst-1", 5, fire_at)

        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        assert s.name == "orchestration:MyOrch:timer"
        assert s.kind == trace.SpanKind.INTERNAL
        assert s.attributes is not None
        assert s.attributes[tracing.ATTR_TASK_TYPE] == "timer"
        assert s.attributes[tracing.ATTR_FIRE_AT] == fire_at.isoformat()
        assert s.attributes[tracing.ATTR_TASK_TASK_ID] == "5"

    def test_backdated_start_time(self, otel_setup: InMemorySpanExporter):
        """Timer span should cover the full wait period when scheduled_time_ns is set."""
        fire_at = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        created_ns = 1704067200_000_000_000  # 2024-01-01T00:00:00Z
        tracing.emit_timer_span(
            "MyOrch", "inst-1", 5, fire_at, scheduled_time_ns=created_ns,
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        # start_time is backdated to the scheduled time; end_time defaults to
        # "now", which is strictly later than the 2024 scheduled timestamp.
        assert spans[0].start_time == created_ns
        assert spans[0].end_time is not None
        assert spans[0].start_time is not None
        assert spans[0].end_time > spans[0].start_time

    def test_parent_trace_context(self, otel_setup: InMemorySpanExporter):
        """Timer span should be parented under the given trace context."""
        traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        parent_ctx = pb.TraceContext(
            traceParent=traceparent,
            spanID="b7ad6b7169203331",
        )
        fire_at = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        tracing.emit_timer_span(
            "MyOrch", "inst-1", 5, fire_at,
            parent_trace_context=parent_ctx,
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        # The span should share the same trace ID as the parent
        expected_trace_id = int("0af7651916cd43dd8448eb211c80319c", 16)
        assert s.context is not None
        assert s.context.trace_id == expected_trace_id
        # The parent span ID should match the parent context
        expected_parent_span_id = int("b7ad6b7169203331", 16)
        assert s.parent is not None
        assert s.parent.span_id == expected_parent_span_id


class TestEmitEventRaisedSpan:
    """Tests for emit_event_raised_span."""

    def test_emits_producer_span(self, otel_setup: InMemorySpanExporter):
        tracing.emit_event_raised_span("approval", "inst-1", target_instance_id="inst-2")

        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        assert s.name == "orchestration_event:approval"
        assert s.kind == trace.SpanKind.PRODUCER
        assert s.attributes is not None
        assert s.attributes[tracing.ATTR_TASK_TYPE] == "event"
        assert s.attributes[tracing.ATTR_EVENT_TARGET_INSTANCE_ID] == "inst-2"

    def test_emits_span_without_target(self, otel_setup: InMemorySpanExporter):
        # When no target instance is given, the target attribute is omitted
        # entirely rather than set to an empty value.
        tracing.emit_event_raised_span("approval", "inst-1")

        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        assert spans[0].attributes is not None
        assert tracing.ATTR_EVENT_TARGET_INSTANCE_ID not in spans[0].attributes

    def test_parent_trace_context(self, otel_setup: InMemorySpanExporter):
        """Event span should be parented under the given trace context."""
        traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        parent_ctx = pb.TraceContext(
            traceParent=traceparent,
            spanID="b7ad6b7169203331",
        )
        tracing.emit_event_raised_span(
            "approval", "inst-1",
            parent_trace_context=parent_ctx,
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        expected_trace_id = int("0af7651916cd43dd8448eb211c80319c", 16)
        assert s.context is not None
        assert s.context.trace_id == expected_trace_id
        expected_parent_span_id = int("b7ad6b7169203331", 16)
        assert s.parent is not None
        assert s.parent.span_id == expected_parent_span_id


# ---------------------------------------------------------------------------
# Tests for reconstruct_trace_context and build_orchestration_trace_context
# ---------------------------------------------------------------------------


class TestReconstructTraceContext:
    """Tests for tracing.reconstruct_trace_context."""

    def test_returns_context_with_given_span_id(self):
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            spanID="b7ad6b7169203331",
        )
        result = tracing.reconstruct_trace_context(parent, "00f067aa0ba902b7")
        assert result is not None
        parts = result.traceParent.split("-")
        assert parts[1] == "0af7651916cd43dd8448eb211c80319c"  # same trace ID
        assert parts[2] == "00f067aa0ba902b7"  # new span ID
        assert result.spanID == "00f067aa0ba902b7"

    def test_preserves_trace_flags(self):
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-00",
            spanID="b7ad6b7169203331",
        )
        result = tracing.reconstruct_trace_context(parent, "1234567890abcdef")
        assert result is not None
        parts = result.traceParent.split("-")
        assert parts[3] == "00"  # flags preserved

    def test_returns_none_for_invalid_parent(self):
        parent = pb.TraceContext(traceParent="invalid", spanID="bad")
        result = tracing.reconstruct_trace_context(parent, "1234567890abcdef")
        assert result is None

    @patch.object(tracing, '_OTEL_AVAILABLE', False)
    def test_returns_none_without_otel(self):
        # With the OpenTelemetry import flag patched off, the helper must
        # degrade to a no-op and return None.
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            spanID="b7ad6b7169203331",
        )
        result = tracing.reconstruct_trace_context(parent, "00f067aa0ba902b7")
        assert result is None


class TestBuildOrchestrationTraceContext:
    """Tests for build_orchestration_trace_context."""

    def test_returns_none_when_start_time_none(self):
        result = tracing.build_orchestration_trace_context(None)
        assert result is None

    def test_builds_context_with_start_time(self):
        start_time_ns = 1704067200000000000  # 2024-01-01T00:00:00Z
        result = tracing.build_orchestration_trace_context(start_time_ns)
        assert result is not None
        assert result.spanStartTime.seconds == 1704067200
        assert result.spanStartTime.nanos == 0

    def test_builds_context_with_span_id(self):
        start_time_ns = 1704067200000000000
        result = tracing.build_orchestration_trace_context(
            start_time_ns, span_id="00f067aa0ba902b7")
        assert result is not None
        assert result.spanStartTime.seconds == 1704067200
        assert result.HasField("spanID")
        assert result.spanID.value == "00f067aa0ba902b7"

    def test_builds_context_without_span_id(self):
        start_time_ns = 1704067200000000000
        result = tracing.build_orchestration_trace_context(start_time_ns)
        assert result is not None
        assert not result.HasField("spanID")


class TestReplayDoesNotEmitSpans:
    """Tests that replayed (old) events do NOT re-emit client spans for
    activities, sub-orchestrations, or timers. CLIENT spans are deferred
    until the completion event arrives as a new event; during replay
    (is_replaying=True) no CLIENT spans are produced."""

    def _get_client_spans(self, exporter):
        """Return non-Server spans (Client/Internal schedule/timer spans)."""
        return [
            s for s in exporter.get_finished_spans()
            if s.kind != trace.SpanKind.SERVER
        ]

    def test_replayed_activity_completion_no_span(self, otel_setup):
        """During a replay dispatch, no CLIENT spans are emitted for
        activities — both old and new completions. The CLIENT span for
        activity 2 was emitted in a prior dispatch when call_activity()
        was first called with is_replaying=False."""
        def dummy_activity(ctx, _):
            pass

        def orchestrator(ctx: task.OrchestrationContext, _):
            r1 = yield ctx.call_activity(dummy_activity, input=1)
            r2 = yield ctx.call_activity(dummy_activity, input=2)
            return [r1, r2]

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        registry.add_activity(dummy_activity)

        # First activity scheduled + completed in old_events (replay)
        old_events = [
            helpers.new_orchestrator_started_event(),
            helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None),
            helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)),
            helpers.new_task_completed_event(1, json.dumps(10)),
        ]
        # Second activity scheduled in replay, completed as new event
        new_events = [
            helpers.new_task_scheduled_event(2, task.get_name(dummy_activity)),
            helpers.new_task_completed_event(2, json.dumps(20)),
        ]

        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
        executor.execute(TEST_INSTANCE_ID, old_events, new_events)

        client_spans = self._get_client_spans(otel_setup)
        # No CLIENT spans during replay — they were emitted in prior dispatches
        assert len(client_spans) == 0

    def test_replayed_activity_failure_no_span(self, otel_setup):
        """During a replay dispatch, no CLIENT spans are emitted for
        failed activities."""
        def failing_activity(ctx, _):
            raise ValueError("boom")

        def orchestrator(ctx: task.OrchestrationContext, _):
            try:
                yield ctx.call_activity(failing_activity, input=1)
            except task.TaskFailedError:
                pass
            result = yield ctx.call_activity(failing_activity, input=2)
            return result

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        registry.add_activity(failing_activity)

        ex = Exception("boom")
        old_events = [
            helpers.new_orchestrator_started_event(),
            helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None),
            helpers.new_task_scheduled_event(1, task.get_name(failing_activity)),
            helpers.new_task_failed_event(1, ex),
        ]
        new_events = [
            helpers.new_task_scheduled_event(2, task.get_name(failing_activity)),
            helpers.new_task_completed_event(2, json.dumps("ok")),
        ]

        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
        executor.execute(TEST_INSTANCE_ID, old_events, new_events)

        client_spans = self._get_client_spans(otel_setup)
        # No CLIENT spans during replay
        assert len(client_spans) == 0

    def test_replayed_timer_no_span(self, otel_setup):
        """A timer that fired during replay should not emit a timer span;
        only the timer whose firing arrives as a NEW event produces one."""
        from datetime import timedelta

        def orchestrator(ctx: task.OrchestrationContext, _):
            t1 = ctx.current_utc_datetime + timedelta(seconds=1)
            yield ctx.create_timer(t1)
            t2 = ctx.current_utc_datetime + timedelta(seconds=2)
            yield ctx.create_timer(t2)
            return "done"

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)

        start_time = datetime(2020, 1, 1, 12, 0, 0)
        fire_at_1 = start_time + timedelta(seconds=1)
        fire_at_2 = start_time + timedelta(seconds=2)

        # First timer created, fired, and second timer created all in old events
        old_events = [
            helpers.new_orchestrator_started_event(start_time),
            helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None),
            helpers.new_timer_created_event(1, fire_at_1),
            helpers.new_timer_fired_event(1, fire_at_1),
            helpers.new_timer_created_event(2, fire_at_2),
        ]
        # Only the second timer firing is a new event
        new_events = [
            helpers.new_timer_fired_event(2, fire_at_2),
        ]

        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
        executor.execute(TEST_INSTANCE_ID, old_events, new_events)

        client_spans = self._get_client_spans(otel_setup)
        # Only the second timer (new event) should produce a span
        assert len(client_spans) == 1
        assert "timer" in client_spans[0].name.lower()

    def test_replayed_sub_orchestration_completion_no_span(self, otel_setup):
        """During a replay dispatch, no CLIENT spans are emitted for
        sub-orchestrations."""
        def sub_orch(ctx: task.OrchestrationContext, _):
            return "sub_done"

        def orchestrator(ctx: task.OrchestrationContext, _):
            r1 = yield ctx.call_sub_orchestrator(sub_orch)
            r2 = yield ctx.call_sub_orchestrator(sub_orch)
            return [r1, r2]

        registry = worker._Registry()
        sub_name = registry.add_orchestrator(sub_orch)
        orch_name = registry.add_orchestrator(orchestrator)

        old_events = [
            helpers.new_orchestrator_started_event(),
            helpers.new_execution_started_event(orch_name, TEST_INSTANCE_ID, encoded_input=None),
            helpers.new_sub_orchestration_created_event(1, sub_name, "sub-1", encoded_input=None),
            helpers.new_sub_orchestration_completed_event(1, encoded_output=json.dumps("r1")),
            helpers.new_sub_orchestration_created_event(2, sub_name, "sub-2", encoded_input=None),
        ]
        new_events = [
            helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("r2")),
        ]

        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
        executor.execute(TEST_INSTANCE_ID, old_events, new_events)

        client_spans = self._get_client_spans(otel_setup)
        # No CLIENT spans during replay
        assert len(client_spans) == 0

    def test_replayed_sub_orchestration_failure_no_span(self, otel_setup):
        """During a replay dispatch, no CLIENT spans are emitted for
        failed sub-orchestrations."""
        def sub_orch(ctx: task.OrchestrationContext, _):
            raise ValueError("sub failed")

        def orchestrator(ctx: task.OrchestrationContext, _):
            try:
                yield ctx.call_sub_orchestrator(sub_orch)
            except task.TaskFailedError:
                pass
            result = yield ctx.call_sub_orchestrator(sub_orch)
            return result

        registry = worker._Registry()
        sub_name = registry.add_orchestrator(sub_orch)
        orch_name = registry.add_orchestrator(orchestrator)

        ex = Exception("sub failed")
        old_events = [
            helpers.new_orchestrator_started_event(),
            helpers.new_execution_started_event(orch_name, TEST_INSTANCE_ID, encoded_input=None),
            helpers.new_sub_orchestration_created_event(1, sub_name, "sub-1", encoded_input=None),
            helpers.new_sub_orchestration_failed_event(1, ex),
            helpers.new_sub_orchestration_created_event(2, sub_name, "sub-2", encoded_input=None),
        ]
        new_events = [
            helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("ok")),
        ]

        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
        executor.execute(TEST_INSTANCE_ID, old_events, new_events)

        client_spans = self._get_client_spans(otel_setup)
        # No CLIENT spans during replay
        assert len(client_spans) == 0


class TestOrchestrationSpanLifecycle:
    """Tests that orchestration SERVER spans are only emitted on completion
    (emit-and-close pattern) — no inter-dispatch storage."""

    def _get_orch_server_spans(self, exporter):
        """Return orchestration SERVER spans from the exporter."""
        return [
            s for s in exporter.get_finished_spans()
            if s.kind == trace.SpanKind.SERVER
        ]

    def _make_worker_with_registry(self, registry):
        """Create a TaskHubGrpcWorker with a pre-populated registry."""
        from unittest.mock import MagicMock
        w = worker.TaskHubGrpcWorker(host_address="localhost:4001")
        w._registry = registry
        # The MagicMock stands in for the gRPC stub; no server is contacted.
        return w, MagicMock()

    def test_intermediate_dispatch_does_not_export_span(self, otel_setup):
        """An intermediate dispatch (no completeOrchestration) should NOT
        export an orchestration SERVER span."""
        from datetime import timedelta

        def orchestrator(ctx: task.OrchestrationContext, _):
            due = ctx.current_utc_datetime + timedelta(seconds=1)
            yield ctx.create_timer(due)
            return "done"

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        w, stub = self._make_worker_with_registry(registry)

        start_time = datetime(2020, 1, 1, 12, 0, 0)
        req = pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            newEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None),
            ],
        )
        w._execute_orchestrator(req, stub, "token1")

        # No span exported — orchestration is not yet complete
        assert len(self._get_orch_server_spans(otel_setup)) == 0

    def test_final_dispatch_exports_single_span(self, otel_setup):
        """Across multiple dispatches, only one orchestration span should
        be exported, and only when the orchestration completes."""
        from datetime import timedelta

        def orchestrator(ctx: task.OrchestrationContext, _):
            due = ctx.current_utc_datetime + timedelta(seconds=2)
            yield ctx.create_timer(due)
            results = yield task.when_all([
                ctx.call_activity(dummy_activity, input=i)
                for i in range(3)
            ])
            return results

        def dummy_activity(ctx, _):
            pass

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        registry.add_activity(dummy_activity)
        w, stub = self._make_worker_with_registry(registry)

        start_time = datetime(2020, 1, 1, 12, 0, 0)
        fire_at = start_time + timedelta(seconds=2)
        activity_name = task.get_name(dummy_activity)

        # Dispatch 1: start
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            newEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None),
            ],
        ), stub, "t1")
        assert len(self._get_orch_server_spans(otel_setup)) == 0

        # Dispatch 2: timer fires
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            pastEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None),
                helpers.new_timer_created_event(1, fire_at),
            ],
            newEvents=[
                helpers.new_timer_fired_event(1, fire_at),
            ],
        ), stub, "t2")
        assert len(self._get_orch_server_spans(otel_setup)) == 0

        # Dispatch 3: activities complete
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            pastEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None),
                helpers.new_timer_created_event(1, fire_at),
                helpers.new_timer_fired_event(1, fire_at),
                helpers.new_task_scheduled_event(2, activity_name),
                helpers.new_task_scheduled_event(3, activity_name),
                helpers.new_task_scheduled_event(4, activity_name),
            ],
            newEvents=[
                helpers.new_task_completed_event(2, json.dumps("r1")),
                helpers.new_task_completed_event(3, json.dumps("r2")),
                helpers.new_task_completed_event(4, json.dumps("r3")),
            ],
        ), stub, "t3")

        # Exactly one orchestration span exported
        orch_spans = self._get_orch_server_spans(otel_setup)
        assert len(orch_spans) == 1
        assert "orchestration" in orch_spans[0].name

    def test_error_exports_failed_span(self, otel_setup):
        """When an orchestration raises an unhandled error, a span is
        exported with ERROR status."""

        def orchestrator(ctx: task.OrchestrationContext, _):
            raise ValueError("orchestration error")

        registry = worker._Registry()
        registry.add_orchestrator(orchestrator)
        w, stub = self._make_worker_with_registry(registry)

        name = task.get_name(orchestrator)
        req = pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            newEvents=[
                helpers.new_orchestrator_started_event(),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None),
            ],
        )
        w._execute_orchestrator(req, stub, "token1")

        orch_spans = self._get_orch_server_spans(otel_setup)
        assert len(orch_spans) == 1
        assert orch_spans[0].status.status_code == StatusCode.ERROR

    def test_separate_instances_get_separate_spans(self, otel_setup):
        """Two different orchestration instances should produce independent
        spans when each completes."""

        def orchestrator(ctx: task.OrchestrationContext, _):
            return "done"

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        w, stub = self._make_worker_with_registry(registry)

        start_time = datetime(2020, 1, 1, 12, 0, 0)
        instance_a = "inst-a"
        instance_b = "inst-b"

        for iid in (instance_a, instance_b):
            w._execute_orchestrator(pb.OrchestratorRequest(
                instanceId=iid,
                newEvents=[
                    helpers.new_orchestrator_started_event(start_time),
                    helpers.new_execution_started_event(
                        name, iid, encoded_input=None),
                ],
            ), stub, f"t-{iid}")

        orch_spans = self._get_orch_server_spans(otel_setup)
        assert len(orch_spans) == 2

    def test_initial_dispatch_defers_activity_client_spans(self, otel_setup):
        """On the first dispatch, no CLIENT span is emitted because the
        span is deferred until the activity completes (taskCompleted /
        taskFailed arrives in a later dispatch)."""

        def dummy_activity(ctx, _):
            pass

        def orchestrator(ctx: task.OrchestrationContext, _):
            yield ctx.call_activity(dummy_activity, input="hello")
            return "done"

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        registry.add_activity(dummy_activity)
        w, stub = self._make_worker_with_registry(registry)

        start_time = datetime(2020, 1, 1, 12, 0, 0)

        # First dispatch — generator runs with is_replaying=False
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            newEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None),
            ],
        ), stub, "t1")

        # No CLIENT span yet — deferred until completion
        client_spans = [
            s for s in otel_setup.get_finished_spans()
            if s.kind == trace.SpanKind.CLIENT
        ]
        assert len(client_spans) == 0

    def test_span_id_consistent_across_dispatches(self, otel_setup):
        """The orchestration span ID must be persisted across dispatches
        so that child spans (activities, timers) are all parented under
        the same orchestration SERVER span."""
        from datetime import timedelta  # noqa: F401

        def dummy_activity(ctx, _):
            pass

        def orchestrator(ctx: task.OrchestrationContext, _):
            r1 = yield ctx.call_activity(dummy_activity, input=1)
            r2 = yield ctx.call_activity(dummy_activity, input=2)
            return [r1, r2]

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        registry.add_activity(dummy_activity)
        w, stub = self._make_worker_with_registry(registry)

        start_time = datetime(2020, 1, 1, 12, 0, 0)
        activity_name = task.get_name(dummy_activity)

        # Dispatch 1: orchestration starts, first activity scheduled
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            newEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None,
                    parent_trace_context=pb.TraceContext(
                        traceParent=f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_PARENT_SPAN_ID}-01",
                        spanID=_SAMPLE_PARENT_SPAN_ID,
                    ),
                ),
            ],
        ), stub, "t1")

        # Capture the orchestration trace context from the response
        call_args = stub.CompleteOrchestratorTask.call_args
        resp1 = call_args[0][0]
        orch_trace_ctx_1 = resp1.orchestrationTraceContext
        assert orch_trace_ctx_1.HasField("spanID")
        span_id_1 = orch_trace_ctx_1.spanID.value
        assert span_id_1 != ""

        otel_setup.clear()

        # Dispatch 2: first activity completes, second activity scheduled
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            orchestrationTraceContext=orch_trace_ctx_1,
            pastEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None,
                    parent_trace_context=pb.TraceContext(
                        traceParent=f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_PARENT_SPAN_ID}-01",
                        spanID=_SAMPLE_PARENT_SPAN_ID,
                    ),
                ),
                helpers.new_task_scheduled_event(1, activity_name),
            ],
            newEvents=[
                helpers.new_task_completed_event(1, json.dumps(10)),
            ],
        ), stub, "t2")

        # Capture the orchestration trace context from the second response
        call_args = stub.CompleteOrchestratorTask.call_args
        resp2 = call_args[0][0]
        orch_trace_ctx_2 = resp2.orchestrationTraceContext
        assert orch_trace_ctx_2.HasField("spanID")
        span_id_2 = orch_trace_ctx_2.spanID.value

        # The span ID must be the same across dispatches
        assert span_id_1 == span_id_2

    def test_child_spans_parented_under_orchestrator_span(self, otel_setup):
        """Activities and timers should be parented under the orchestration
        SERVER span, and the orchestrator span ID must be consistent across
        dispatches."""
        from datetime import timedelta

        def dummy_activity(ctx, _):
            pass

        def orchestrator(ctx: task.OrchestrationContext, _):
            due = ctx.current_utc_datetime + timedelta(seconds=1)
            yield ctx.create_timer(due)
            r1 = yield ctx.call_activity(dummy_activity, input=1)
            return r1

        registry = worker._Registry()
        name = registry.add_orchestrator(orchestrator)
        registry.add_activity(dummy_activity)
        w, stub = self._make_worker_with_registry(registry)

        start_time = datetime(2020, 1, 1, 12, 0, 0)
        fire_at = start_time + timedelta(seconds=1)
        activity_name = task.get_name(dummy_activity)

        parent_tp = f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_PARENT_SPAN_ID}-01"
        parent_ctx = pb.TraceContext(
            traceParent=parent_tp,
            spanID=_SAMPLE_PARENT_SPAN_ID,
        )

        # Dispatch 1: start, timer scheduled
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            newEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None,
                    parent_trace_context=parent_ctx,
                ),
            ],
        ), stub, "t1")

        call_args = stub.CompleteOrchestratorTask.call_args
        resp1 = call_args[0][0]
        orch_trace_ctx = resp1.orchestrationTraceContext
        orch_span_id = orch_trace_ctx.spanID.value
        otel_setup.clear()

        # Dispatch 2: timer fires, activity scheduled
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            orchestrationTraceContext=orch_trace_ctx,
            pastEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None,
                    parent_trace_context=parent_ctx,
                ),
                helpers.new_timer_created_event(1, fire_at),
            ],
            newEvents=[
                helpers.new_timer_fired_event(1, fire_at),
            ],
        ), stub, "t2")

        # Timer span should be parented under the orchestration span
        timer_spans = [
            s for s in otel_setup.get_finished_spans()
            if s.kind == trace.SpanKind.INTERNAL and "timer" in s.name
        ]
        assert len(timer_spans) == 1
        assert timer_spans[0].parent is not None
        assert timer_spans[0].parent.span_id == int(orch_span_id, 16)

        call_args = stub.CompleteOrchestratorTask.call_args
        resp2 = call_args[0][0]
        orch_trace_ctx_2 = resp2.orchestrationTraceContext
        # Span ID must be consistent
        assert orch_trace_ctx_2.spanID.value == orch_span_id
        otel_setup.clear()

        # Dispatch 3: activity completes, orchestration finishes
        w._execute_orchestrator(pb.OrchestratorRequest(
            instanceId=TEST_INSTANCE_ID,
            orchestrationTraceContext=orch_trace_ctx_2,
            pastEvents=[
                helpers.new_orchestrator_started_event(start_time),
                helpers.new_execution_started_event(
                    name, TEST_INSTANCE_ID, encoded_input=None,
                    parent_trace_context=parent_ctx,
                ),
                helpers.new_timer_created_event(1, fire_at),
                helpers.new_timer_fired_event(1, fire_at),
                helpers.new_task_scheduled_event(2, activity_name),
            ],
            newEvents=[
                helpers.new_task_completed_event(2, json.dumps("result")),
            ],
        ), stub, "t3")

        # Orchestration SERVER span should use the same span ID
        orch_spans = self._get_orch_server_spans(otel_setup)
        assert len(orch_spans) == 1
        assert orch_spans[0].context.span_id == int(orch_span_id, 16)

        # Orchestration should be parented under the PRODUCER span
        assert orch_spans[0].parent is not None
        assert orch_spans[0].parent.span_id == int(_SAMPLE_PARENT_SPAN_ID, 16)


# ---------------------------------------------------------------------------
# Tests for _parse_traceparent
# ---------------------------------------------------------------------------


class TestParseTraceparent:
    """Tests for tracing._parse_traceparent."""

    def test_valid_traceparent(self):
        tp = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        result = tracing._parse_traceparent(tp)
        assert result is not None
        trace_id, span_id, flags = result
        assert trace_id == int("0af7651916cd43dd8448eb211c80319c", 16)
        assert span_id == int("b7ad6b7169203331", 16)
        assert flags == 1

    def test_invalid_format(self):
        assert tracing._parse_traceparent("not-a-traceparent") is None

    def test_too_few_parts(self):
        assert tracing._parse_traceparent("00-abc") is None

    def test_zero_trace_id(self):
        # All-zero trace IDs are invalid per the W3C Trace Context spec.
        tp = "00-00000000000000000000000000000000-b7ad6b7169203331-01"
        assert tracing._parse_traceparent(tp) is None

    def test_zero_span_id(self):
        tp = "00-0af7651916cd43dd8448eb211c80319c-0000000000000000-01"
        assert tracing._parse_traceparent(tp) is None

    def test_non_hex_values(self):
        tp = "00-zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz-b7ad6b7169203331-01"
        assert tracing._parse_traceparent(tp) is None

    def test_flags_zero(self):
        tp = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-00"
        result = tracing._parse_traceparent(tp)
        assert result is not None
        assert result[2] == 0


# ---------------------------------------------------------------------------
# Tests for generate_client_trace_context
# ---------------------------------------------------------------------------


class TestGenerateClientTraceContext:
    """Tests for tracing.generate_client_trace_context."""

    def test_returns_none_without_parent(self):
        assert tracing.generate_client_trace_context(None) is None

    def test_returns_none_with_invalid_parent(self):
        ctx = pb.TraceContext(traceParent="invalid", spanID="bad")
        assert tracing.generate_client_trace_context(ctx) is None

    def test_generates_valid_traceparent(self):
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            spanID="b7ad6b7169203331",
        )
        result = tracing.generate_client_trace_context(parent)
        assert result is not None
        assert result.traceParent != ""
        assert result.spanID != ""

    def test_preserves_trace_id(self):
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            spanID="b7ad6b7169203331",
        )
        result = tracing.generate_client_trace_context(parent)
        assert result is not None
        parts = result.traceParent.split("-")
        assert parts[1] == "0af7651916cd43dd8448eb211c80319c"

    def test_generates_different_span_id(self):
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            spanID="b7ad6b7169203331",
        )
        result = tracing.generate_client_trace_context(parent)
        assert result is not None
        assert result.spanID != parent.spanID

    def test_span_id_matches_traceparent(self):
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            spanID="b7ad6b7169203331",
        )
        result = tracing.generate_client_trace_context(parent)
        assert result is not None
        # The spanID field must agree with the span-ID segment embedded in
        # the generated traceparent string.
        parts = result.traceParent.split("-")
        assert parts[2] == result.spanID

    @patch.object(tracing, '_OTEL_AVAILABLE', False)
    def test_returns_none_without_otel(self):
        parent = pb.TraceContext(
            traceParent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            spanID="b7ad6b7169203331",
        )
        assert tracing.generate_client_trace_context(parent) is None


# ---------------------------------------------------------------------------
# Tests for emit_client_span
# ---------------------------------------------------------------------------


class TestEmitClientSpan:
    """Tests for tracing.emit_client_span."""

    def test_emits_span_with_correct_attributes(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "activity", "SayHello", "inst-1", task_id=5,
            client_trace_context=_make_client_trace_ctx(),
            parent_trace_context=_make_parent_trace_ctx(),
            start_time_ns=1_000_000_000,
            end_time_ns=2_000_000_000,
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        assert s.name == "activity:SayHello"
        assert s.kind == trace.SpanKind.CLIENT
        assert s.attributes is not None
        assert s.attributes[tracing.ATTR_TASK_TYPE] == "activity"
        assert s.attributes[tracing.ATTR_TASK_NAME] == "SayHello"
        assert s.attributes[tracing.ATTR_TASK_INSTANCE_ID] == "inst-1"
        assert s.attributes[tracing.ATTR_TASK_TASK_ID] == "5"

    def test_span_has_correct_trace_and_span_ids(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "activity", "Act", "inst-1", task_id=1,
            client_trace_context=_make_client_trace_ctx(),
            parent_trace_context=_make_parent_trace_ctx(),
            start_time_ns=1_000_000_000,
            end_time_ns=2_000_000_000,
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        assert s.context is not None
        assert s.context.trace_id == int(_SAMPLE_TRACE_ID, 16)
        assert s.context.span_id == int(_SAMPLE_CLIENT_SPAN_ID, 16)

    def test_span_parent_matches_parent_trace_context(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "activity", "Act", "inst-1", task_id=1,
            client_trace_context=_make_client_trace_ctx(),
            parent_trace_context=_make_parent_trace_ctx(),
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        s = spans[0]
        assert s.parent is not None
        assert s.parent.span_id == int(_SAMPLE_PARENT_SPAN_ID, 16)

    def test_error_span(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "activity", "Act", "inst-1", task_id=1,
            client_trace_context=_make_client_trace_ctx(),
            parent_trace_context=_make_parent_trace_ctx(),
            is_error=True,
            error_message="task failed",
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        assert spans[0].status.status_code == StatusCode.ERROR
        assert spans[0].status.description is not None
        assert "task failed" in spans[0].status.description

    def test_custom_timestamps(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "activity", "Act", "inst-1", task_id=1,
            client_trace_context=_make_client_trace_ctx(),
            start_time_ns=5_000_000_000,
            end_time_ns=10_000_000_000,
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        assert spans[0].start_time == 5_000_000_000
        assert spans[0].end_time == 10_000_000_000

    def test_version_included(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "orchestration", "SubOrch", "inst-1", task_id=1,
            client_trace_context=_make_client_trace_ctx(),
            version="2.0",
        )
        spans = otel_setup.get_finished_spans()
        assert len(spans) == 1
        assert spans[0].name == "orchestration:SubOrch@(2.0)"
        assert spans[0].attributes is not None
        assert spans[0].attributes[tracing.ATTR_TASK_VERSION] == "2.0"

    def test_noop_with_invalid_client_trace_context(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "activity", "Act", "inst-1", task_id=1,
            client_trace_context=pb.TraceContext(traceParent="bad"),
        )
        assert len(otel_setup.get_finished_spans()) == 0

    @patch.object(tracing, '_OTEL_AVAILABLE', False)
    def test_noop_without_otel(self, otel_setup: InMemorySpanExporter):
        tracing.emit_client_span(
            "activity", "Act", "inst-1", task_id=1,
            client_trace_context=_make_client_trace_ctx(),
        )
        assert len(otel_setup.get_finished_spans()) == 0


# ---------------------------------------------------------------------------
# Integration tests for deferred CLIENT span lifecycle
# ---------------------------------------------------------------------------


class TestDeferredClientSpanIntegration:
    """End-to-end tests verifying that CLIENT spans are emitted with proper
    timestamps when taskCompleted / taskFailed / sub-orchestration events
    arrive as new events."""

    def _make_traced_task_scheduled_event(
        self, event_id, name, client_traceparent, timestamp_seconds=100,
    ):
        """Build a taskScheduled event with parentTraceContext and timestamp."""
        ts = timestamp_pb2.Timestamp()
        ts.FromSeconds(timestamp_seconds)
        return pb.HistoryEvent(
            eventId=event_id,
            timestamp=ts,
            taskScheduled=pb.TaskScheduledEvent(
                name=name,
                parentTraceContext=pb.TraceContext(
                    traceParent=client_traceparent,
                    # spanID is the third dash-separated segment of the
                    # traceparent string.
                    spanID=client_traceparent.split("-")[2],
                ),
            ),
        )

    def _make_traced_task_completed_event(self, task_id, result, timestamp_seconds=200):
        # Build a taskCompleted event carrying an explicit timestamp so the
        # deferred CLIENT span's end time can be asserted.
        ts = timestamp_pb2.Timestamp()
        ts.FromSeconds(timestamp_seconds)
        return pb.HistoryEvent(
            eventId=-1,
            timestamp=ts,
            taskCompleted=pb.TaskCompletedEvent(
                taskScheduledId=task_id,
                result=wrappers_pb2.StringValue(value=result) if result else None,
            ),
        )

    def _make_traced_task_failed_event(self, task_id, error_msg, timestamp_seconds=200):
        # Build a taskFailed event with failure details and an explicit
        # timestamp.
        ts = timestamp_pb2.Timestamp()
        ts.FromSeconds(timestamp_seconds)
        return pb.HistoryEvent(
            eventId=-1,
            timestamp=ts,
            taskFailed=pb.TaskFailedEvent(
                taskScheduledId=task_id,
                failureDetails=pb.TaskFailureDetails(
                    errorMessage=error_msg, errorType="TaskFailedError",
                ),
            ),
        )

    def _make_execution_started_with_trace(self, name, instance_id):
        # executionStarted event whose parentTraceContext uses the shared
        # sample trace/span IDs defined at module level.
        parent_tp = f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_PARENT_SPAN_ID}-01"
        return pb.HistoryEvent(
            eventId=-1,
            timestamp=timestamp_pb2.Timestamp(),
            executionStarted=pb.ExecutionStartedEvent(
                name=name,
                orchestrationInstance=pb.OrchestrationInstance(instanceId=instance_id),
                parentTraceContext=pb.TraceContext(
                    traceParent=parent_tp,
                    spanID=_SAMPLE_PARENT_SPAN_ID,
                ),
            ),
        )

    def test_activity_completed_emits_client_span(self, otel_setup: InMemorySpanExporter):
        """When taskCompleted arrives as a new event, a CLIENT span is
        emitted with start_time=taskScheduled.timestamp and
        end_time=taskCompleted.timestamp."""

        def dummy_activity(ctx, _):
            pass

        def orchestrator(ctx: task.OrchestrationContext, _):
            result = yield ctx.call_activity(dummy_activity, input="hi")
            return result

        registry = worker._Registry()
        orch_name = registry.add_orchestrator(orchestrator)
        registry.add_activity(dummy_activity)
        act_name = task.get_name(dummy_activity)

        client_tp = f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_CLIENT_SPAN_ID}-01"

        # Dispatch 2: old events replay the scheduling, new event completes
        old_events = [
            helpers.new_orchestrator_started_event(),
            self._make_execution_started_with_trace(orch_name, TEST_INSTANCE_ID),
            self._make_traced_task_scheduled_event(1, act_name, client_tp, timestamp_seconds=100),
        ]
        new_events = [
            self._make_traced_task_completed_event(1, json.dumps("result"), timestamp_seconds=200),
        ]

        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
        result = executor.execute(TEST_INSTANCE_ID, old_events, new_events)

        client_spans = [
            s for s in otel_setup.get_finished_spans()
            if s.kind == trace.SpanKind.CLIENT
        ]
        assert len(client_spans) == 1
        s = client_spans[0]
        assert "activity" in s.name
        assert s.context is not None
assert s.context.span_id == int(_SAMPLE_CLIENT_SPAN_ID, 16) + assert s.context.trace_id == int(_SAMPLE_TRACE_ID, 16) + assert s.start_time == 100_000_000_000 # 100 seconds in ns + assert s.end_time == 200_000_000_000 # 200 seconds in ns + # Parent should be the orchestration span, not the PRODUCER span + assert s.parent is not None + orch_ctx = result._orchestration_trace_context + assert orch_ctx is not None + assert s.parent.span_id == int(orch_ctx.spanID, 16) + assert s.parent.span_id != int(_SAMPLE_PARENT_SPAN_ID, 16) + + def test_activity_failed_emits_error_client_span(self, otel_setup: InMemorySpanExporter): + """When taskFailed arrives as a new event, a CLIENT span is + emitted with ERROR status.""" + + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + try: + yield ctx.call_activity(dummy_activity, input="hi") + except task.TaskFailedError: + return "caught" + + registry = worker._Registry() + orch_name = registry.add_orchestrator(orchestrator) + registry.add_activity(dummy_activity) + act_name = task.get_name(dummy_activity) + + client_tp = f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_CLIENT_SPAN_ID}-01" + + old_events = [ + helpers.new_orchestrator_started_event(), + self._make_execution_started_with_trace(orch_name, TEST_INSTANCE_ID), + self._make_traced_task_scheduled_event(1, act_name, client_tp, timestamp_seconds=100), + ] + new_events = [ + self._make_traced_task_failed_event(1, "boom", timestamp_seconds=250), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = [ + s for s in otel_setup.get_finished_spans() + if s.kind == trace.SpanKind.CLIENT + ] + assert len(client_spans) == 1 + s = client_spans[0] + assert s.status.status_code == StatusCode.ERROR + assert s.status.description is not None + assert "boom" in s.status.description + assert s.start_time == 100_000_000_000 + assert s.end_time == 250_000_000_000 + + def 
test_sub_orchestration_completed_emits_client_span(self, otel_setup: InMemorySpanExporter): + """When subOrchestrationInstanceCompleted arrives as a new event, + a CLIENT span is emitted.""" + + def sub_orch(ctx: task.OrchestrationContext, _): + return "sub_result" + + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.call_sub_orchestrator(sub_orch) + return result + + registry = worker._Registry() + sub_name = registry.add_orchestrator(sub_orch) + orch_name = registry.add_orchestrator(orchestrator) + + client_tp = f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_CLIENT_SPAN_ID}-01" + sub_instance_id = f"{TEST_INSTANCE_ID}:0001" + + ts_created = timestamp_pb2.Timestamp() + ts_created.FromSeconds(150) + ts_completed = timestamp_pb2.Timestamp() + ts_completed.FromSeconds(300) + + old_events = [ + helpers.new_orchestrator_started_event(), + self._make_execution_started_with_trace(orch_name, TEST_INSTANCE_ID), + pb.HistoryEvent( + eventId=1, + timestamp=ts_created, + subOrchestrationInstanceCreated=pb.SubOrchestrationInstanceCreatedEvent( + name=sub_name, + instanceId=sub_instance_id, + parentTraceContext=pb.TraceContext( + traceParent=client_tp, + spanID=_SAMPLE_CLIENT_SPAN_ID, + ), + ), + ), + ] + new_events = [ + pb.HistoryEvent( + eventId=-1, + timestamp=ts_completed, + subOrchestrationInstanceCompleted=pb.SubOrchestrationInstanceCompletedEvent( + taskScheduledId=1, + result=wrappers_pb2.StringValue(value=json.dumps("sub_result")), + ), + ), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = [ + s for s in otel_setup.get_finished_spans() + if s.kind == trace.SpanKind.CLIENT + ] + assert len(client_spans) == 1 + s = client_spans[0] + assert "orchestration" in s.name + assert s.context is not None + assert s.context.span_id == int(_SAMPLE_CLIENT_SPAN_ID, 16) + assert s.start_time == 150_000_000_000 + assert s.end_time == 300_000_000_000 + + def 
test_replayed_completion_does_not_emit_client_span(self, otel_setup: InMemorySpanExporter): + """When both taskScheduled and taskCompleted are in old_events + (full replay), no CLIENT span is emitted.""" + + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + r1 = yield ctx.call_activity(dummy_activity, input=1) + r2 = yield ctx.call_activity(dummy_activity, input=2) + return [r1, r2] + + registry = worker._Registry() + orch_name = registry.add_orchestrator(orchestrator) + registry.add_activity(dummy_activity) + act_name = task.get_name(dummy_activity) + + client_tp = f"00-{_SAMPLE_TRACE_ID}-{_SAMPLE_CLIENT_SPAN_ID}-01" + + old_events = [ + helpers.new_orchestrator_started_event(), + self._make_execution_started_with_trace(orch_name, TEST_INSTANCE_ID), + self._make_traced_task_scheduled_event(1, act_name, client_tp, timestamp_seconds=100), + self._make_traced_task_completed_event(1, json.dumps(10), timestamp_seconds=200), + ] + # Second activity completes as new event — but NO parentTraceContext + # on its taskScheduled, so no CLIENT span for it either + new_events = [ + helpers.new_task_scheduled_event(2, act_name), + helpers.new_task_completed_event(2, json.dumps(20)), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = [ + s for s in otel_setup.get_finished_spans() + if s.kind == trace.SpanKind.CLIENT + ] + # taskCompleted(1) in old_events -> replaying -> no span + # taskCompleted(2) in new_events -> no parentTraceContext on taskScheduled -> no span + assert len(client_spans) == 0