diff --git a/pyproject.toml b/pyproject.toml index 8d8313ab..2f61f6d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] license = { text = "BSD-3-Clause" } requires-python = ">=3.10" -dependencies = ["bidict", "pika", "setuptools", "stomp-py>=7"] +dependencies = ["bidict", "pika", "setuptools", "stomp-py>=7", "opentelemetry-api==1.20.0", "opentelemetry-sdk==1.20.0", "opentelemetry-exporter-otlp-proto-http==1.20.0" ] [project.urls] Download = "https://github.com/DiamondLightSource/python-workflows/releases" @@ -53,6 +53,7 @@ OfflineTransport = "workflows.transport.offline_transport:OfflineTransport" pika = "workflows.util.zocalo.configuration:Pika" stomp = "workflows.util.zocalo.configuration:Stomp" transport = "workflows.util.zocalo.configuration:DefaultTransport" +opentelemetry = "workflows.util.zocalo.configuration:OTEL" [project.scripts] "workflows.validate_recipe" = "workflows.recipe.validate:main" diff --git a/requirements_dev.txt b/requirements_dev.txt index 8207c45b..bd711be7 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -7,3 +7,7 @@ pytest-mock==3.14.0 pytest-timeout==2.3.1 stomp-py==8.1.2 websocket-client==1.8.0 +opentelemetry-api==1.20.0 +opentelemetry-sdk==1.20.0 +opentelemetry-exporter-otlp-proto-http==1.20.0 +marshmallow \ No newline at end of file diff --git a/src/workflows/recipe/__init__.py b/src/workflows/recipe/__init__.py index 0f1973f4..73bba19a 100644 --- a/src/workflows/recipe/__init__.py +++ b/src/workflows/recipe/__init__.py @@ -3,8 +3,11 @@ import functools import logging from collections.abc import Callable +from contextlib import ExitStack from typing import Any +from opentelemetry import trace + from workflows.recipe.recipe import Recipe from workflows.recipe.validate import validate_recipe from workflows.recipe.wrapper import RecipeWrapper @@ -68,11 +71,52 @@ def unwrap_recipe(header, message): if mangle_for_receiving: message = mangle_for_receiving(message) if 
header.get("workflows-recipe") in {True, "True", "true", 1}: + otel_logs = None rw = RecipeWrapper(message=message, transport=transport_layer) - if log_extender and rw.environment and rw.environment.get("ID"): - with log_extender("recipe_ID", rw.environment["ID"]): + + if hasattr(rw, "environment") and rw.environment.get("ID"): + # Extract recipe ID from environment and add to current span + span = trace.get_current_span() + recipe_id = rw.environment.get("ID") + + if recipe_id: + span.set_attribute("recipe_id", recipe_id) + + # Extract span_id and trace_id for logging + span_context = span.get_span_context() + if span_context and span_context.is_valid: + span_id = span_context.span_id + trace_id = span_context.trace_id + + otel_logs = { + "span_id": span_id, + "trace_id": trace_id, + } + + if recipe_id: + otel_logs["recipe_id"] = recipe_id + + with ExitStack() as stack: + # Configure the context depending on if service is emitting spans + if ( + otel_logs + and log_extender + and rw.environment + and rw.environment.get("ID") + ): + stack.enter_context( + log_extender("recipe_ID", rw.environment.get("ID")) + ) + stack.enter_context(log_extender("otel_logs", otel_logs)) + elif log_extender and rw.environment and rw.environment.get("ID"): + stack.enter_context( + log_extender("recipe_ID", rw.environment.get("ID")) + ) + return callback(rw, header, message.get("payload")) + return callback(rw, header, message.get("payload")) + if allow_non_recipe_messages: return callback(None, header, message) # self.log.warning('Discarding non-recipe message:\n' + \ diff --git a/src/workflows/services/common_service.py b/src/workflows/services/common_service.py index de2ef704..1f09306e 100644 --- a/src/workflows/services/common_service.py +++ b/src/workflows/services/common_service.py @@ -9,8 +9,15 @@ import time from typing import Any +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from 
opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + import workflows import workflows.logging +from workflows.transport.middleware.otel_tracing import OTELTracingMiddleware class Status(enum.Enum): @@ -185,6 +192,40 @@ def start_transport(self): self.transport.subscription_callback_set_intercept( self._transport_interceptor ) + + # Configure OTELTracing if configuration is available + otel_config = ( + self.config._opentelemetry + if self.config and hasattr(self.config, "opentelemetry") + else None + ) + if otel_config: + # Configure OTELTracing + resource = Resource.create( + { + SERVICE_NAME: self._service_name, + } + ) + + self.log.debug("Configuring OTELTracing") + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + + # Configure BatchProcessor and OTLPSpanExporter using config values + otlp_exporter = OTLPSpanExporter( + endpoint=otel_config["endpoint"], + timeout=otel_config.get("timeout", 10), + ) + span_processor = BatchSpanProcessor(otlp_exporter) + provider.add_span_processor(span_processor) + + # Add OTELTracingMiddleware to the transport layer + tracer = trace.get_tracer(__name__) + otel_middleware = OTELTracingMiddleware( + tracer, service_name=self._service_name + ) + self._transport.add_middleware(otel_middleware) + metrics = self._environment.get("metrics") if metrics: import prometheus_client diff --git a/src/workflows/transport/middleware/__init__.py b/src/workflows/transport/middleware/__init__.py index 1ace0ff0..aeb60fbb 100644 --- a/src/workflows/transport/middleware/__init__.py +++ b/src/workflows/transport/middleware/__init__.py @@ -233,6 +233,10 @@ def wrapped_callback(header, message): def wrap(f: Callable): + # debugging + if f.__name__ == "send": + print("we are wrapping send now") + @functools.wraps(f) def wrapper(self, *args, **kwargs): return functools.reduce( @@ -243,4 +247,5 @@ 
def wrapper(self, *args, **kwargs): lambda *args, **kwargs: f(self, *args, **kwargs), )(*args, **kwargs) + print(wrapper.__wrapped__) return wrapper diff --git a/src/workflows/transport/middleware/otel_tracing.py b/src/workflows/transport/middleware/otel_tracing.py new file mode 100644 index 00000000..13cd82e2 --- /dev/null +++ b/src/workflows/transport/middleware/otel_tracing.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import functools +import json +from collections.abc import Callable + +from opentelemetry import trace +from opentelemetry.context import Context +from opentelemetry.propagate import extract, inject + +from workflows.transport.common_transport import MessageCallback, TemporarySubscription + + +class OTELTracingMiddleware: + def __init__(self, tracer: trace.Tracer, service_name: str): + self.tracer = tracer + self.service_name = service_name + + def send(self, call_next: Callable, destination: str, message, **kwargs): + # Get current span context (may be None if this is the root span) + current_span = trace.get_current_span() + parent_context = ( + trace.set_span_in_context(current_span) if current_span else None + ) + + with self.tracer.start_as_current_span( + "transport.send", + context=parent_context, + ) as span: + span.set_attribute("service_name", self.service_name) + + span.set_attribute("message", json.dumps(message)) + span.set_attribute("destination", destination) + print("parent_context is...", parent_context) + + # Inject the current trace context into the message headers + headers = kwargs.get("headers", {}) + if headers is None: + headers = {} + inject(headers) # This modifies headers in-place + kwargs["headers"] = headers + + return call_next(destination, message, **kwargs) + + def subscribe( + self, call_next: Callable, channel: str, callback: Callable, **kwargs + ) -> int: + @functools.wraps(callback) + def wrapped_callback(header, message): + # Extract trace context from message headers + ctx = extract(header) if 
header else Context() + + # Start a new span with the extracted context + with self.tracer.start_as_current_span( + "transport.subscribe", + context=ctx, + ) as span: + span.set_attribute("service_name", self.service_name) + + span.set_attribute("message", json.dumps(message)) + span.set_attribute("channel", channel) + + # Call the original callback - this will process the message + # and potentially call send() which will pick up this context + return callback(header, message) + + return call_next(channel, wrapped_callback, **kwargs) + + def subscribe_broadcast( + self, call_next: Callable, channel: str, callback: Callable, **kwargs + ) -> int: + @functools.wraps(callback) + def wrapped_callback(header, message): + # Extract trace context from message headers + ctx = extract(header) if header else Context() + + # # Start a new span with the extracted context + with self.tracer.start_as_current_span( + "transport.subscribe_broadcast", + context=ctx, + ) as span: + span.set_attribute("service_name", self.service_name) + + span.set_attribute("message", json.dumps(message)) + span.set_attribute("channel", channel) + + return callback(header, message) + + return call_next(channel, wrapped_callback, **kwargs) + + def subscribe_temporary( + self, + call_next: Callable, + channel_hint: str | None, + callback: MessageCallback, + **kwargs, + ) -> TemporarySubscription: + @functools.wraps(callback) + def wrapped_callback(header, message): + # Extract trace context from message headers + ctx = extract(header) if header else Context() + + # Start a new span with the extracted context + with self.tracer.start_as_current_span( + "transport.subscribe_temporary", + context=ctx, + ) as span: + span.set_attribute("service_name", self.service_name) + + span.set_attribute("message", json.dumps(message)) + if channel_hint: + span.set_attribute("channel_hint", channel_hint) + + return callback(header, message) + + return call_next(channel_hint, wrapped_callback, **kwargs) + + def 
unsubscribe( + self, + call_next: Callable, + subscription: int, + drop_callback_reference=False, + **kwargs, + ): + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transport.unsubscribe", + context=current_context, + ) as span: + span.set_attribute("service_name", self.service_name) + span.set_attribute("subscription_id", subscription) + + call_next( + subscription, drop_callback_reference=drop_callback_reference, **kwargs + ) + + def ack( + self, + call_next: Callable, + message, + subscription_id: int | None = None, + **kwargs, + ): + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transport.ack", + context=current_context, + ) as span: + span.set_attribute("service_name", self.service_name) + span.set_attribute("message", json.dumps(message)) + if subscription_id: + span.set_attribute("subscription_id", subscription_id) + + call_next(message, subscription_id=subscription_id, **kwargs) + + def nack( + self, + call_next: Callable, + message, + subscription_id: int | None = None, + **kwargs, + ): + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transport.nack", + context=current_context, + ) as span: + span.set_attribute("service_name", self.service_name) + + span.set_attribute("message", json.dumps(message)) + if subscription_id: + span.set_attribute("subscription_id", subscription_id) + + call_next(message, subscription_id=subscription_id, **kwargs) + + def transaction_begin( + self, call_next: Callable, subscription_id: int | None = None, **kwargs + ) -> int: + 
"""Start a new transaction span""" + # Get current span context (may be None if this is the root span) + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transaction.begin", + context=current_context, + ) as span: + span.set_attribute("service_name", self.service_name) + + if subscription_id: + span.set_attribute("subscription_id", subscription_id) + + return call_next(subscription_id=subscription_id, **kwargs) + + def transaction_abort( + self, call_next: Callable, transaction_id: int | None = None, **kwargs + ): + """Abort a transaction span""" + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transaction.abort", + context=current_context, + ) as span: + span.set_attribute("service_name", self.service_name) + + if transaction_id: + span.set_attribute("transaction_id", transaction_id) + + call_next(transaction_id=transaction_id, **kwargs) + + def transaction_commit( + self, call_next: Callable, transaction_id: int | None = None, **kwargs + ): + """Commit a transaction span""" + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transaction.commit", + context=current_context, + ) as span: + span.set_attribute("service_name", self.service_name) + if transaction_id: + span.set_attribute("transaction_id", transaction_id) + + call_next(transaction_id=transaction_id, **kwargs) diff --git a/src/workflows/util/zocalo/configuration.py b/src/workflows/util/zocalo/configuration.py index 08a600aa..48e52962 100644 --- a/src/workflows/util/zocalo/configuration.py +++ b/src/workflows/util/zocalo/configuration.py @@ -8,6 
class OTEL:
    """A Zocalo configuration plugin to pre-populate OTELTracing config defaults

    ``activate`` resolves the trace exporter endpoint and timeout from the
    plugin configuration and stores them in the class-level ``config``
    dictionary, which is also the value it returns.
    """

    class Schema(PluginSchema):
        host = fields.Str(required=True)
        port = fields.Int(required=True)
        endpoint = fields.Str(required=False)
        timeout = fields.Int(required=False, load_default=10)

    # Store configuration for access by services.
    # NOTE: this dictionary is shared by the whole process; every call to
    # activate() overwrites the same "endpoint" and "timeout" entries.
    config = {}

    @staticmethod
    def activate(configuration):
        """Resolve the exporter settings and publish them on ``OTEL.config``.

        Args:
            configuration: Mapping holding ``host`` and ``port`` and,
                optionally, ``endpoint`` and ``timeout``.

        Returns:
            The shared ``OTEL.config`` dictionary with ``endpoint`` and
            ``timeout`` set.
        """
        # An explicit endpoint wins. A missing *or empty* endpoint falls back
        # to the URL built from host and port, so a blank setting can never
        # be handed on as the exporter endpoint.
        endpoint = configuration.get("endpoint") or (
            f"https://{configuration['host']}:{configuration['port']}/v1/traces"
        )

        OTEL.config["endpoint"] = endpoint
        OTEL.config["timeout"] = configuration.get("timeout", 10)

        return OTEL.config