From 3ae673e8d8e8219824a90adeef8e46bec51da78d Mon Sep 17 00:00:00 2001 From: Jorge Lisa <64639359+AcquaDiGiorgio@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:06:47 +0100 Subject: [PATCH] feat(otel): proof of concept for tracing using decorators note: metrics are untested --- .../src/diracx/routers/auth/utils.py | 19 +++- .../src/diracx/routers/auth/well_known.py | 6 +- diracx-routers/src/diracx/routers/otel.py | 90 +++++++++++++++++++ 3 files changed, 112 insertions(+), 3 deletions(-) diff --git a/diracx-routers/src/diracx/routers/auth/utils.py b/diracx-routers/src/diracx/routers/auth/utils.py index 3b8813611..48ac1a5d3 100644 --- a/diracx-routers/src/diracx/routers/auth/utils.py +++ b/diracx-routers/src/diracx/routers/auth/utils.py @@ -13,6 +13,8 @@ from cryptography.fernet import Fernet from fastapi import Depends, HTTPException, status +from ..otel import async_tracer, sync_tracer, set_trace_attribute + from diracx.core.properties import ( SecurityProperty, UnevaluatedProperty, @@ -59,7 +61,7 @@ async def require_property( _server_metadata_cache: TTLCache = TTLCache(maxsize=1024, ttl=3600) - +@async_tracer() async def get_server_metadata(url: str): """Get the server metadata from the IAM.""" server_metadata = _server_metadata_cache.get(url) @@ -71,6 +73,8 @@ async def get_server_metadata(url: str): raise NotImplementedError(res) server_metadata = res.json() _server_metadata_cache[url] = server_metadata + + set_trace_attribute("meta", server_metadata) return server_metadata @@ -161,7 +165,7 @@ async def verify_dirac_refresh_token( return (token["jti"], float(token["exp"]), token["legacy_exchange"]) - +@sync_tracer() def parse_and_validate_scope( scope: str, config: Config, available_properties: set[SecurityProperty] ) -> ScopeInfoDict: @@ -224,6 +228,10 @@ def parse_and_validate_scope( f"{set(properties)-set(available_properties)} are not valid properties" ) + set_trace_attribute("group", group) + set_trace_attribute("properties", set(sorted(properties))) + set_trace_attribute("vo", vo) + return { "group": group, "properties": set(sorted(properties)), @@ -231,6 +239,7 @@ def parse_and_validate_scope( } +@async_tracer() async def initiate_authorization_flow_with_iam( config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet ): @@ -272,6 +281,12 @@ async def initiate_authorization_flow_with_iam( f"state={encrypted_state}", ] authorization_flow_url = f"{authorization_endpoint}?{'&'.join(urlParams)}" + + set_trace_attribute("endpoint", authorization_endpoint) + for param in urlParams: + k, v = param.split("=", maxsplit=1) + set_trace_attribute(k, v) + return authorization_flow_url diff --git a/diracx-routers/src/diracx/routers/auth/well_known.py b/diracx-routers/src/diracx/routers/auth/well_known.py index 94b6d11a1..c3a16355a 100644 --- a/diracx-routers/src/diracx/routers/auth/well_known.py +++ b/diracx-routers/src/diracx/routers/auth/well_known.py @@ -6,6 +6,7 @@ from ..dependencies import Config, DevelopmentSettings from ..fastapi_classes import DiracxRouter from ..utils.users import AuthSettings +from ..otel import async_tracer, set_trace_attribute router = DiracxRouter(require_auth=False, path_root="") @@ -24,7 +25,7 @@ async def openid_configuration( scopes_supported += [f"group:{vo}" for vo in config.Registry[vo].Groups] scopes_supported += [f"property:{p}" for p in settings.available_properties] - return { + res = { "issuer": settings.token_issuer, "token_endpoint": str(request.url_for("token")), "userinfo_endpoint:": str(request.url_for("userinfo")), @@ -43,6 +44,9 @@ async def openid_configuration( "code_challenge_methods_supported": ["S256"], } + set_trace_attribute("config", res) + + return res class SupportInfo(TypedDict): message: str diff --git a/diracx-routers/src/diracx/routers/otel.py b/diracx-routers/src/diracx/routers/otel.py index cddf601e5..4a86be61c 100644 --- a/diracx-routers/src/diracx/routers/otel.py +++ b/diracx-routers/src/diracx/routers/otel.py @@ -23,6 +23,96 @@ from diracx.core.settings import ServiceSettingsBase +from functools import wraps +import inspect + +from collections import UserDict +from timeit import default_timer as timer + + +def async_tracer(name=None): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + # Obtain the module that contains the decorated function + package_name = get_module_name_from_func(func) + + if not name: + tace_name = func.__name__ + + tracer = trace.get_tracer_provider().get_tracer(package_name) + + # Create a span with name: diracx.diracx_xxx.(...).package.function + with tracer.start_as_current_span(f"{package_name}.{tace_name}"): + return await func(*args, **kwargs) + + return wrapper + return decorator + +def sync_tracer(name=None): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Obtain the module that contains the decorated function + module_name = get_module_name_from_func(func) + + if not name: + tace_name = func.__name__ + + tracer = trace.get_tracer_provider().get_tracer(module_name) + + # Create a span with name: diracx.diracx_xxx.(...).package.function + with tracer.start_as_current_span(f"{module_name}.{tace_name}"): + return func(*args, **kwargs) + + return wrapper + return decorator + +def set_trace_attribute(key, value, stringify=False): + span = trace.get_current_span() + if stringify: + span.set_attribute(f"diracx.{key}", str(value)) + else: + _recursive_set_trace_attribute(span, f"diracx.{key}", value) + + +def _recursive_set_trace_attribute(span, key, value): + if isinstance(value, list): + zeros = len(str(len(value))) + for idx, item in enumerate(value): + _recursive_set_trace_attribute(span, f"{key}[{str(idx).zfill(zeros)}]", item) + + elif isinstance(value, set) or isinstance(value, tuple): + zeros = len(str(len(value))) + for idx, item in enumerate(value): + _recursive_set_trace_attribute(span, f"{key}.item_{str(idx).zfill(zeros)}", item) + + elif isinstance(value, dict) or isinstance(value, UserDict): + for k, v in value.items(): + _recursive_set_trace_attribute(span, f"{key}.{k}", v) + + else: + span.set_attribute(key, value) + + +def increase_counter(meter_name, counter_name, amount=1, is_updown=False): + meter = metrics.get_meter_provider().get_meter(meter_name) + + if is_updown: + metric = meter.create_up_down_counter(counter_name) + else: + metric = meter.create_counter(counter_name) + + metric.add(amount) + +def get_module_name_from_func(func): + from_module = inspect.getmodule(func) + + if not from_module: + return "diracx" + + module_name = from_module.__name__ + return module_name class OTELSettings(ServiceSettingsBase): """Settings for the Open Telemetry Configuration."""