11from __future__ import annotations
22
33import functools
4+ import json
45from collections .abc import Callable
56
67from opentelemetry import trace
7- from opentelemetry .propagate import extract , inject
88from opentelemetry .context import Context
9+ from opentelemetry .propagate import extract , inject
10+
11+ from workflows .transport .common_transport import MessageCallback , TemporarySubscription
912
10- from workflows .transport .middleware import BaseTransportMiddleware
11- from workflows .transport .common_transport import TemporarySubscription , MessageCallback
12- import json
1313
1414class OTELTracingMiddleware :
1515 def __init__ (self , tracer : trace .Tracer , service_name : str ):
@@ -19,7 +19,9 @@ def __init__(self, tracer: trace.Tracer, service_name: str):
1919 def send (self , call_next : Callable , destination : str , message : Any , ** kwargs ):
2020 # Get current span context (may be None if this is the root span)
2121 current_span = trace .get_current_span ()
22- parent_context = trace .set_span_in_context (current_span ) if current_span else None
22+ parent_context = (
23+ trace .set_span_in_context (current_span ) if current_span else None
24+ )
2325
2426 with self .tracer .start_as_current_span (
2527 "transport.send" ,
@@ -29,8 +31,7 @@ def send(self, call_next: Callable, destination: str, message: Any, **kwargs):
2931
3032 span .set_attribute ("message" , json .dumps (message ))
3133 span .set_attribute ("destination" , destination )
32- print ("parent_context is..." ,parent_context )
33-
34+ print ("parent_context is..." , parent_context )
3435
3536 # Inject the current trace context into the message headers
3637 headers = kwargs .get ("headers" , {})
@@ -41,7 +42,9 @@ def send(self, call_next: Callable, destination: str, message: Any, **kwargs):
4142
4243 return call_next (destination , message , ** kwargs )
4344
44- def subscribe (self , call_next : Callable , channel : str , callback : Callable , ** kwargs ) -> int :
45+ def subscribe (
46+ self , call_next : Callable , channel : str , callback : Callable , ** kwargs
47+ ) -> int :
4548 @functools .wraps (callback )
4649 def wrapped_callback (header , message ):
4750 # Extract trace context from message headers
@@ -63,13 +66,15 @@ def wrapped_callback(header, message):
6366
6467 return call_next (channel , wrapped_callback , ** kwargs )
6568
66- def subscribe_broadcast (self , call_next : Callable , channel : str , callback : Callable , ** kwargs ) -> int :
69+ def subscribe_broadcast (
70+ self , call_next : Callable , channel : str , callback : Callable , ** kwargs
71+ ) -> int :
6772 @functools .wraps (callback )
6873 def wrapped_callback (header , message ):
6974 # Extract trace context from message headers
7075 ctx = extract (header ) if header else Context ()
7176
72- # # Start a new span with the extracted context
77+ # # Start a new span with the extracted context
7378 with self .tracer .start_as_current_span (
7479 "transport.subscribe_broadcast" ,
7580 context = ctx ,
@@ -119,7 +124,9 @@ def unsubscribe(
119124 ):
120125 # Get current span context
121126 current_span = trace .get_current_span ()
122- current_context = trace .set_span_in_context (current_span ) if current_span else Context ()
127+ current_context = (
128+ trace .set_span_in_context (current_span ) if current_span else Context ()
129+ )
123130
124131 with self .tracer .start_as_current_span (
125132 "transport.unsubscribe" ,
@@ -141,7 +148,9 @@ def ack(
141148 ):
142149 # Get current span context
143150 current_span = trace .get_current_span ()
144- current_context = trace .set_span_in_context (current_span ) if current_span else Context ()
151+ current_context = (
152+ trace .set_span_in_context (current_span ) if current_span else Context ()
153+ )
145154
146155 with self .tracer .start_as_current_span (
147156 "transport.ack" ,
@@ -163,7 +172,9 @@ def nack(
163172 ):
164173 # Get current span context
165174 current_span = trace .get_current_span ()
166- current_context = trace .set_span_in_context (current_span ) if current_span else Context ()
175+ current_context = (
176+ trace .set_span_in_context (current_span ) if current_span else Context ()
177+ )
167178
168179 with self .tracer .start_as_current_span (
169180 "transport.nack" ,
@@ -183,7 +194,9 @@ def transaction_begin(
183194 """Start a new transaction span"""
184195 # Get current span context (may be None if this is the root span)
185196 current_span = trace .get_current_span ()
186- current_context = trace .set_span_in_context (current_span ) if current_span else Context ()
197+ current_context = (
198+ trace .set_span_in_context (current_span ) if current_span else Context ()
199+ )
187200
188201 with self .tracer .start_as_current_span (
189202 "transaction.begin" ,
@@ -196,11 +209,15 @@ def transaction_begin(
196209
197210 return call_next (subscription_id = subscription_id , ** kwargs )
198211
199- def transaction_abort (self , call_next : Callable , transaction_id : int | None = None , ** kwargs ):
212+ def transaction_abort (
213+ self , call_next : Callable , transaction_id : int | None = None , ** kwargs
214+ ):
200215 """Abort a transaction span"""
201216 # Get current span context
202217 current_span = trace .get_current_span ()
203- current_context = trace .set_span_in_context (current_span ) if current_span else Context ()
218+ current_context = (
219+ trace .set_span_in_context (current_span ) if current_span else Context ()
220+ )
204221
205222 with self .tracer .start_as_current_span (
206223 "transaction.abort" ,
@@ -213,11 +230,15 @@ def transaction_abort(self, call_next: Callable, transaction_id: int | None = No
213230
214231 call_next (transaction_id = transaction_id , ** kwargs )
215232
216- def transaction_commit (self , call_next : Callable , transaction_id : int | None = None , ** kwargs ):
233+ def transaction_commit (
234+ self , call_next : Callable , transaction_id : int | None = None , ** kwargs
235+ ):
217236 """Commit a transaction span"""
218237 # Get current span context
219238 current_span = trace .get_current_span ()
220- current_context = trace .set_span_in_context (current_span ) if current_span else Context ()
239+ current_context = (
240+ trace .set_span_in_context (current_span ) if current_span else Context ()
241+ )
221242
222243 with self .tracer .start_as_current_span (
223244 "transaction.commit" ,
@@ -228,4 +249,3 @@ def transaction_commit(self, call_next: Callable, transaction_id: int | None = N
228249 span .set_attribute ("transaction_id" , transaction_id )
229250
230251 call_next (transaction_id = transaction_id , ** kwargs )
231-
0 commit comments