diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 60dd868c2e335..5bc9bb748cc56 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -29,6 +29,9 @@
 import structlog
 from cadwyn import VersionedAPIRouter
 from fastapi import Body, HTTPException, Query, Security, status
+from opentelemetry import trace
+from opentelemetry.trace import StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from pydantic import JsonValue
 from sqlalchemy import and_, func, or_, tuple_, update
 from sqlalchemy.engine import CursorResult
@@ -37,6 +40,7 @@
 from sqlalchemy.sql import select
 from structlog.contextvars import bind_contextvars
 
+from airflow._shared.observability.traces import override_ids
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
 from airflow.api_fastapi.common.db.common import SessionDep
@@ -87,6 +91,7 @@
 
 
 log = structlog.get_logger(__name__)
+tracer = trace.get_tracer(__name__)
 
 
 @ti_id_router.patch(
@@ -431,6 +436,45 @@ def ti_update_state(
     )
 
 
+def _emit_task_span(ti, state):
+    # Skip span emission if the dag run or either trace carrier is missing.
+    if not ti.dag_run:
+        return
+    if not isinstance(ti.dag_run.context_carrier, dict):
+        return
+    if not isinstance(ti.context_carrier, dict):
+        return
+    dr_ctx = TraceContextTextMapPropagator().extract(ti.dag_run.context_carrier)
+
+    ti_ctx = TraceContextTextMapPropagator().extract(ti.context_carrier)
+    ti_span = trace.get_current_span(context=ti_ctx)
+    span_context = ti_span.get_span_context()
+    start_time_candidates = (x for x in (ti.queued_dttm, ti.start_date, timezone.utcnow()) if x)
+    name = f"task_run.{ti.task_id}"
+    if ti.map_index >= 0:
+        name += f"_{ti.map_index}"
+    with override_ids(span_context.trace_id, span_context.span_id):
+        span = tracer.start_span(
+            name=name,
+            start_time=int(min(start_time_candidates).timestamp() * 1e9),
+            context=dr_ctx,
+        )
+
+    span.set_attributes(
+        {
+            "airflow.dag_id": ti.dag_id,
+            "airflow.task_id": ti.task_id,
+            "airflow.dag_run.run_id": ti.run_id,
+            "airflow.task_instance.try_number": ti.try_number,
+            "airflow.task_instance.map_index": ti.map_index if ti.map_index is not None else -1,
+            "airflow.task_instance.state": state,
+        }
+    )
+    status_code = StatusCode.OK if state == TaskInstanceState.SUCCESS else StatusCode.ERROR
+    span.set_status(status_code)
+    span.end()
+
+
 def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> None:
     dr = ti.dag_run
 
@@ -479,6 +523,7 @@ def _create_ti_state_update_query_and_update_state(
             ti_patch_payload.outlet_events,
             session,
         )
+        _emit_task_span(ti, state=updated_state)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py
index a5939cf424412..4ca8c310fb5c2 100644
--- a/airflow-core/src/airflow/executors/workloads/task.py
+++ b/airflow-core/src/airflow/executors/workloads/task.py
@@ -86,7 +86,6 @@ def make(
         from airflow.utils.helpers import log_filename_template_renderer
 
         ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
-        ser_ti.context_carrier = ti.dag_run.context_carrier
         if not bundle_info:
            bundle_info = BundleInfo(
                name=ti.dag_model.bundle_name,
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 1406283c05cb3..9be1da1168607 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -35,6 +35,9 @@
 import anyio
 import attrs
 import structlog
+from opentelemetry import trace
+from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from pydantic import BaseModel, Field, TypeAdapter
 from sqlalchemy import func, select
 from structlog.contextvars import bind_contextvars as bind_log_contextvars
@@ -87,6 +90,7 @@
 from airflow.utils.session import provide_session
 
 if TYPE_CHECKING:
+    from opentelemetry.util._decorator import _AgnosticContextManager
     from sqlalchemy.orm import Session
     from structlog.typing import FilteringBoundLogger, WrappedLogger
 
@@ -96,6 +100,33 @@
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
 
 logger = logging.getLogger(__name__)
+tracer = trace.get_tracer(__name__)
+
+
+def _make_trigger_span(
+    ti: TaskInstanceDTO | None, trigger_id: int, name: str
+) -> _AgnosticContextManager[trace.Span]:
+    parent_context = (
+        TraceContextTextMapPropagator().extract(ti.context_carrier) if ti and ti.context_carrier else None
+    )
+    span_name = f"trigger.{ti.task_id}" if ti else f"trigger.{trigger_id}"
+    if ti and ti.map_index >= 0:
+        span_name += f"_{ti.map_index}"
+    attributes: dict[str, str | int] = {
+        "airflow.trigger.name": name,
+    }
+    if ti:
+        attributes = {
+            **attributes,
+            "airflow.dag_id": ti.dag_id,
+            "airflow.task_id": ti.task_id,
+            "airflow.dag_run.run_id": ti.run_id,
+            "airflow.task_instance.try_number": ti.try_number,
+            "airflow.task_instance.map_index": ti.map_index,
+        }
+
+    return tracer.start_as_current_span(span_name, attributes=attributes, context=parent_context)
+
 
 __all__ = [
     "TriggerRunner",
@@ -1179,30 +1210,36 @@ async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after
         name = self.triggers[trigger_id]["name"]
         self.log.info("trigger %s starting", name)
 
-        try:
-            async for event in trigger.run():
-                await self.log.ainfo(
-                    "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event
-                )
-                self.triggers[trigger_id]["events"] += 1
-                self.events.append((trigger_id, event))
-        except asyncio.CancelledError:
-            # We get cancelled by the scheduler changing the task state. But if we do lets give a nice error
-            # message about it
-            if timeout := timeout_after:
-                timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
-                if timeout < timezone.utcnow():
-                    await self.log.aerror("Trigger cancelled due to timeout")
-            raise
-        finally:
-            # CancelledError will get injected when we're stopped - which is
-            # fine, the cleanup process will understand that, but we want to
-            # allow triggers a chance to cleanup, either in that case or if
-            # they exit cleanly. Exception from cleanup methods are ignored.
-            with suppress(Exception):
-                await trigger.cleanup()
-
-            await self.log.ainfo("trigger completed", name=name)
+        with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span:
+            try:
+                async for event in trigger.run():
+                    await self.log.ainfo(
+                        "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event
+                    )
+                    self.triggers[trigger_id]["events"] += 1
+                    self.events.append((trigger_id, event))
+                span.set_status(Status(StatusCode.OK))
+            except asyncio.CancelledError as e:
+                # We get cancelled by the scheduler changing the task state. But if we do let's give a nice error
+                # message about it
+                if timeout := timeout_after:
+                    timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
+                    if timeout < timezone.utcnow():
+                        await self.log.aerror("Trigger cancelled due to timeout")
+                span.set_status(Status(StatusCode.ERROR, description=str(e)))
+                raise
+            except Exception as e:
+                span.set_status(Status(StatusCode.ERROR, description=str(e)))
+                raise
+            finally:
+                # CancelledError will get injected when we're stopped - which is
+                # fine, the cleanup process will understand that, but we want to
+                # allow triggers a chance to cleanup, either in that case or if
+                # they exit cleanly. Exception from cleanup methods are ignored.
+                with suppress(Exception):
+                    await trigger.cleanup()
+
+                await self.log.ainfo("trigger completed", name=name)
 
     def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
         """
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 23fdfb7256453..6bbeaef19f227 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1019,6 +1019,10 @@ def is_effective_leaf(task):
         return leaf_tis
 
     def _emit_dagrun_span(self, state: DagRunState):
+        # Skip span emission if the trace carrier was never populated.
+        if not isinstance(self.context_carrier, dict):
+            return
+
         ctx = TraceContextTextMapPropagator().extract(self.context_carrier)
         span = trace.get_current_span(context=ctx)
         span_context = span.get_span_context()
@@ -1026,6 +1030,10 @@ def _emit_dagrun_span(self, state: DagRunState):
         attributes = {
             "airflow.dag_id": str(self.dag_id),
             "airflow.dag_run.run_id": self.run_id,
+            "airflow.dag_run.start_date": str(self.start_date) if self.start_date else None,
+            "airflow.dag_run.end_date": str(self.end_date) if self.end_date else None,
+            "airflow.dag_run.queued_at": str(self.queued_at) if self.queued_at else None,
+            "airflow.dag_run.created_at": str(self.created_at) if self.created_at else None,
         }
         if self.logical_date:
             attributes["airflow.dag_run.logical_date"] = str(self.logical_date)
@@ -1033,7 +1041,7 @@ def _emit_dagrun_span(self, state: DagRunState):
             attributes["airflow.dag_run.partition_key"] = str(self.partition_key)
         span = tracer.start_span(
             name=f"dag_run.{self.dag_id}",
-            start_time=int((self.start_date or timezone.utcnow()).timestamp() * 1e9),
+            start_time=int((self.queued_at or self.start_date or timezone.utcnow()).timestamp() * 1e9),
             attributes=attributes,
             context=context.Context(),
         )
@@ -1763,7 +1771,11 @@ def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[s
                 created_counts[task.task_type] += 1
                 for map_index in indexes:
                     yield TI.insert_mapping(
-                        self.run_id, task, map_index=map_index, dag_version_id=dag_version_id
+                        self.run_id,
+                        task,
+                        map_index=map_index,
+                        dag_version_id=dag_version_id,
+                        dag_run=self,
                     )
 
             creator = create_ti_mapping
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index c8a9c05bd9fd3..2e6e44eb3c7a0 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -32,6 +32,8 @@
 import attrs
 import dill
 import uuid6
+from opentelemetry import trace
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from sqlalchemy import (
     JSON,
     Float,
@@ -67,6 +69,7 @@
 from airflow import settings
 from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager
 from airflow._shared.observability.metrics.stats import Stats
+from airflow._shared.observability.traces import new_dagrun_trace_carrier
 from airflow._shared.timezones import timezone
 from airflow.assets.manager import asset_manager
 from airflow.configuration import conf
@@ -102,7 +105,7 @@
 TR = TaskReschedule
 
 log = logging.getLogger(__name__)
-
+tracer = trace.get_tracer(__name__)
 
 if TYPE_CHECKING:
     from datetime import datetime
@@ -382,7 +385,7 @@ def clear_task_instances(
     for instance in tis:
         run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
 
-    drs = session.scalars(
+    drs: Iterable[DagRun] = session.scalars(
         select(DagRun).where(
             or_(
                 *(
@@ -397,6 +400,7 @@
         # Always update clear_number and queued_at when clearing tasks, regardless of state
         dr.clear_number += 1
         dr.queued_at = timezone.utcnow()
+        dr.context_carrier = new_dagrun_trace_carrier()
 
         _recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session)
 
@@ -425,6 +429,8 @@
         if dag_run_state == DagRunState.QUEUED:
             dr.last_scheduling_decision = None
             dr.start_date = None
+    for ti in tis:
+        ti.context_carrier = _make_task_carrier(ti.dag_run.context_carrier)
 
     session.flush()
 
@@ -486,6 +492,17 @@ def uuid7() -> UUID:
     return uuid6.uuid7()
 
 
+def _make_task_carrier(dag_run_context_carrier):
+    parent_context = (
+        TraceContextTextMapPropagator().extract(dag_run_context_carrier) if dag_run_context_carrier else None
+    )
+    span = tracer.start_span("notused", context=parent_context)  # intentionally never closed
+    new_ctx = trace.set_span_in_context(span)
+    carrier: dict[str, str] = {}
+    TraceContextTextMapPropagator().inject(carrier, context=new_ctx)
+    return carrier
+
+
 class TaskInstance(Base, LoggingMixin, BaseWorkload):
     """
     Task instances store the state of a task instance.
@@ -679,7 +696,7 @@ def stats_tags(self) -> dict[str, str]:
 
     @staticmethod
     def insert_mapping(
-        run_id: str, task: Operator, map_index: int, dag_version_id: UUID | None
+        run_id: str, task: Operator, map_index: int, *, dag_version_id: UUID | None, dag_run: DagRun
     ) -> dict[str, Any]:
         """
         Insert mapping.
@@ -689,6 +706,7 @@ def insert_mapping(
         priority_weight = task.weight_rule.get_weight(
             TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
         )
+        context_carrier = _make_task_carrier(dag_run.context_carrier)
 
         return {
             "dag_id": task.dag_id,
@@ -710,6 +728,7 @@
             "map_index": map_index,
             "_task_display_property_value": task.task_display_name,
             "dag_version_id": dag_version_id,
+            "context_carrier": context_carrier,
         }
 
     @reconstructor
diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py
index 18d09d6aa564f..1b160b381a852 100644
--- a/airflow-core/src/airflow/models/taskmap.py
+++ b/airflow-core/src/airflow/models/taskmap.py
@@ -24,9 +24,12 @@
 from collections.abc import Collection, Iterable, Sequence
 from typing import TYPE_CHECKING, Any
 
+from opentelemetry import trace
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select
 from sqlalchemy.orm import Mapped, mapped_column
 
+from airflow.models import DagRun
 from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
 from airflow.models.dag_version import DagVersion
 from airflow.utils.db import exists_query
@@ -38,6 +41,7 @@
     from airflow.models.taskinstance import TaskInstance
     from airflow.serialization.definitions.mappedoperator import Operator
 
+tracer = trace.get_tracer(__name__)
 
 
 class TaskMapVariant(enum.Enum):
@@ -52,6 +56,17 @@ class TaskMapVariant(enum.Enum):
     LIST = "list"
 
 
+def _make_task_carrier(dag_run_context_carrier):
+    parent_context = (
+        TraceContextTextMapPropagator().extract(dag_run_context_carrier) if dag_run_context_carrier else None
+    )
+    span = tracer.start_span("notused", context=parent_context)  # intentionally never closed
+    new_ctx = trace.set_span_in_context(span)
+    carrier: dict[str, str] = {}
+    TraceContextTextMapPropagator().inject(carrier, context=new_ctx)
+    return carrier
+
+
 class TaskMap(TaskInstanceDependencies):
     """
     Model to track dynamic task-mapping information.
@@ -254,6 +269,16 @@ def expand_mapped_task(
                 task.log.debug("Expanding TIs upserted %s", ti)
                 task_instance_mutation_hook(ti)
                 ti = session.merge(ti)
+                if unmapped_ti:
+                    dr = unmapped_ti.dag_run
+                else:
+                    dr = session.scalar(
+                        select(DagRun).where(
+                            DagRun.dag_id == task.dag_id,
+                            DagRun.run_id == run_id,
+                        )
+                    )
+                ti.context_carrier = _make_task_carrier(dr.context_carrier)
                 ti.refresh_from_task(task)  # session.merge() loses task information.
                 all_expanded_tis.append(ti)
diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py
index c9bd4ad8b789d..d042f5b7927ed 100644
--- a/airflow-core/tests/integration/otel/test_otel.py
+++ b/airflow-core/tests/integration/otel/test_otel.py
@@ -503,9 +503,10 @@ def get_parent_span_id(span):
         nested = get_span_hierarchy()
 
         assert nested == {
-            "sub_span1": "task_run.task1",
-            "task_run.task1": "dag_run.otel_test_dag",
             "dag_run.otel_test_dag": None,
+            "sub_span1": "worker.task1",
+            "task_run.task1": "dag_run.otel_test_dag",
+            "worker.task1": "task_run.task1",
         }
 
     def start_scheduler(self, capture_output: bool = False):
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 835a8c46139f7..44e4f35702b06 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -3153,3 +3153,145 @@ def test_no_scope_defaults_to_execution(self, client, session, create_task_insta
         payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
         resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload)
         assert resp.status_code in [200, 204]
+
+
+class TestEmitTaskSpan:
+    """Tests for the _emit_task_span function in the execution API task-instance route."""
+
+    @pytest.fixture(autouse=True)
+    def sdk_tracer_provider(self):
+        from opentelemetry.sdk.trace import TracerProvider
+        from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+        from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+        from airflow._shared.observability.traces import OverrideableRandomIdGenerator
+
+        self.exporter = InMemorySpanExporter()
+        provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
+        provider.add_span_processor(SimpleSpanProcessor(self.exporter))
+        test_tracer = provider.get_tracer("test")
+        with mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.tracer", test_tracer):
+            yield
+
+    def _make_carriers(self):
+        """Return a (dr_carrier, ti_carrier) pair built with a real SDK provider."""
+        from opentelemetry import trace as otel_trace
+        from opentelemetry.sdk.trace import TracerProvider
+        from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+
+        p = TracerProvider()
+        t = p.get_tracer("setup")
+        dr_span = t.start_span("dr")
+        dr_ctx = otel_trace.set_span_in_context(dr_span)
+        dr_carrier: dict = {}
+        TraceContextTextMapPropagator().inject(dr_carrier, context=dr_ctx)
+        ti_span = t.start_span("ti", context=dr_ctx)
+        ti_ctx = otel_trace.set_span_in_context(ti_span)
+        ti_carrier: dict = {}
+        TraceContextTextMapPropagator().inject(ti_carrier, context=ti_ctx)
+        return dr_carrier, ti_carrier
+
+    def _make_ti(self, task_id="my_task", map_index=-1, queued_dttm=None, start_date=None):
+        from unittest.mock import MagicMock
+
+        dr_carrier, ti_carrier = self._make_carriers()
+        ti = MagicMock()
+        ti.dag_id = "test_dag"
+        ti.task_id = task_id
+        ti.run_id = "test_run"
+        ti.try_number = 1
+        ti.map_index = map_index
+        ti.queued_dttm = queued_dttm
+        ti.start_date = start_date or DEFAULT_START_DATE
+        ti.dag_run.context_carrier = dr_carrier
+        ti.context_carrier = ti_carrier
+        return ti
+
+    def test_emit_task_span_success_sets_ok_status(self):
+        from opentelemetry.trace import StatusCode
+
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        _emit_task_span(self._make_ti(), TaskInstanceState.SUCCESS)
+
+        spans = self.exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].status.status_code == StatusCode.OK
+
+    def test_emit_task_span_failed_sets_error_status(self):
+        from opentelemetry.trace import StatusCode
+
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        _emit_task_span(self._make_ti(), TaskInstanceState.FAILED)
+
+        spans = self.exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].status.status_code == StatusCode.ERROR
+
+    def test_emit_task_span_sets_attributes(self):
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        ti = self._make_ti(task_id="my_task", map_index=2)
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        attrs = self.exporter.get_finished_spans()[0].attributes
+        assert attrs["airflow.dag_id"] == "test_dag"
+        assert attrs["airflow.task_id"] == "my_task"
+        assert attrs["airflow.dag_run.run_id"] == "test_run"
+        assert attrs["airflow.task_instance.try_number"] == 1
+        assert attrs["airflow.task_instance.map_index"] == 2
+        assert attrs["airflow.task_instance.state"] == TaskInstanceState.SUCCESS
+
+    def test_emit_task_span_name_unmapped(self):
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        _emit_task_span(self._make_ti(task_id="my_task", map_index=-1), TaskInstanceState.SUCCESS)
+        assert self.exporter.get_finished_spans()[0].name == "task_run.my_task"
+
+    def test_emit_task_span_name_mapped(self):
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        _emit_task_span(self._make_ti(task_id="my_task", map_index=3), TaskInstanceState.SUCCESS)
+        assert self.exporter.get_finished_spans()[0].name == "task_run.my_task_3"
+
+    def test_emit_task_span_start_time_uses_queued_dttm(self):
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        queued_dttm = timezone.parse("2024-01-01T10:00:00Z")
+        start_date = timezone.parse("2024-01-01T10:05:00Z")
+        ti = self._make_ti(queued_dttm=queued_dttm, start_date=start_date)
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        assert self.exporter.get_finished_spans()[0].start_time == int(queued_dttm.timestamp() * 1e9)
+
+    def test_emit_task_span_start_time_falls_back_to_start_date(self):
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        start_date = timezone.parse("2024-01-01T10:05:00Z")
+        ti = self._make_ti(queued_dttm=None, start_date=start_date)
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        assert self.exporter.get_finished_spans()[0].start_time == int(start_date.timestamp() * 1e9)
+
+    def test_emit_task_span_skips_if_no_ti_carrier(self):
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        ti = mock.MagicMock()
+        ti.dag_run.context_carrier = {
+            "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
+        }
+        ti.context_carrier = None
+
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+        assert len(self.exporter.get_finished_spans()) == 0
+
+    def test_emit_task_span_skips_if_no_dagrun_carrier(self):
+        from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
+
+        ti = mock.MagicMock()
+        ti.dag_run.context_carrier = None
+        ti.context_carrier = {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}
+
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+        assert len(self.exporter.get_finished_spans()) == 0
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 3d77880427929..c16d5acbb2a52 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -27,6 +27,7 @@
 from collections.abc import AsyncIterator
 from socket import socket
 from typing import TYPE_CHECKING, Any
+from unittest import mock
 from unittest.mock import ANY, AsyncMock, MagicMock, patch
 
 import pendulum
@@ -45,6 +46,7 @@
     TriggerLoggingFactory,
     TriggerRunner,
     TriggerRunnerSupervisor,
+    _make_trigger_span,
     messages,
 )
 from airflow.models import Connection, DagModel, DagRun, Trigger, Variable
@@ -349,11 +351,12 @@ def test_run_inline_trigger_canceled(self, session) -> None:
         mock_trigger = MagicMock(spec=BaseTrigger)
         mock_trigger.timeout_after = None
         mock_trigger.run.side_effect = asyncio.CancelledError()
+        mock_trigger.task_instance = MagicMock()
+        mock_trigger.task_instance.map_index = -1
 
         with pytest.raises(asyncio.CancelledError):
             asyncio.run(trigger_runner.run_trigger(1, mock_trigger))
 
-    # @pytest.mark.asyncio
     def test_run_inline_trigger_timeout(self, session, cap_structlog) -> None:
         trigger_runner = TriggerRunner()
         trigger_runner.triggers = {
@@ -361,6 +364,8 @@
         }
         mock_trigger = MagicMock(spec=BaseTrigger)
         mock_trigger.run.side_effect = asyncio.CancelledError()
+        mock_trigger.task_instance = MagicMock()
+        mock_trigger.task_instance.map_index = -1
 
         with pytest.raises(asyncio.CancelledError):
             asyncio.run(
@@ -1356,3 +1361,101 @@ def get_type_names(union_type):
         + "\n".join(f" - {t}" for t in sorted(task_diff))
         + "\n\nEither handle these types in ToTriggerRunner or update in_task_but_not_in_trigger_runner list."
     )
+
+
+class TestMakeTriggerSpan:
+    """Tests for the _make_trigger_span helper in the triggerer job runner."""
+
+    @pytest.fixture(autouse=True)
+    def sdk_tracer_provider(self):
+        from opentelemetry.sdk.trace import TracerProvider
+        from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+        from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+        self.exporter = InMemorySpanExporter()
+        provider = TracerProvider()
+        provider.add_span_processor(SimpleSpanProcessor(self.exporter))
+        test_tracer = provider.get_tracer("test")
+        with mock.patch("airflow.jobs.triggerer_job_runner.tracer", test_tracer):
+            yield
+
+    def _make_ti_dto(self, task_id="my_task", map_index=-1, context_carrier=None):
+        import uuid
+
+        from airflow.executors.workloads.task import TaskInstanceDTO
+
+        return TaskInstanceDTO(
+            id=uuid.uuid4(),
+            dag_version_id=uuid.uuid4(),
+            task_id=task_id,
+            dag_id="test_dag",
+            run_id="test_run",
+            try_number=1,
+            map_index=map_index,
+            pool_slots=1,
+            queue="default",
+            priority_weight=1,
+            context_carrier=context_carrier,
+        )
+
+    def test_make_trigger_span_name_with_task_instance(self):
+        ti = self._make_ti_dto(task_id="sensor_task", map_index=-1)
+        with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+            pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.sensor_task"
+
+    def test_make_trigger_span_name_with_mapped_task(self):
+        ti = self._make_ti_dto(task_id="sensor_task", map_index=2)
+        with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+            pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.sensor_task_2"
+
+    def test_make_trigger_span_name_without_task_instance(self):
+        with _make_trigger_span(ti=None, trigger_id=42, name="MySensor"):
+            pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.42"
+
+    def test_make_trigger_span_uses_task_context_carrier(self):
+        from opentelemetry import trace as otel_trace
+        from opentelemetry.sdk.trace import TracerProvider
+        from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+
+        # Build a valid ti carrier from a separate provider so we have a known parent span.
+        setup_provider = TracerProvider()
+        setup_tracer = setup_provider.get_tracer("setup")
+        parent_span = setup_tracer.start_span("ti_parent")
+        parent_ctx = otel_trace.set_span_in_context(parent_span)
+        ti_carrier: dict = {}
+        TraceContextTextMapPropagator().inject(ti_carrier, context=parent_ctx)
+        expected_parent_span_id = parent_span.get_span_context().span_id
+
+        ti = self._make_ti_dto(context_carrier=ti_carrier)
+        with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+            pass
+
+        spans = self.exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].parent is not None
+        assert spans[0].parent.span_id == expected_parent_span_id
+
+    def test_make_trigger_span_sets_attributes_with_ti(self):
+        ti = self._make_ti_dto(task_id="my_task", map_index=1)
+        with _make_trigger_span(ti=ti, trigger_id=5, name="MyTrigger"):
+            pass
+
+        attrs = self.exporter.get_finished_spans()[0].attributes
+        assert attrs["airflow.trigger.name"] == "MyTrigger"
+        assert attrs["airflow.dag_id"] == "test_dag"
+        assert attrs["airflow.task_id"] == "my_task"
+        assert attrs["airflow.dag_run.run_id"] == "test_run"
+        assert attrs["airflow.task_instance.try_number"] == 1
+        assert attrs["airflow.task_instance.map_index"] == 1
+
+    def test_make_trigger_span_sets_only_trigger_name_without_ti(self):
+        with _make_trigger_span(ti=None, trigger_id=99, name="OnlyTrigger"):
+            pass
+
+        attrs = self.exporter.get_finished_spans()[0].attributes
+        assert attrs["airflow.trigger.name"] == "OnlyTrigger"
+        assert "airflow.dag_id" not in attrs
+        assert "airflow.task_id" not in attrs
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 82b9adc3162ae..eb13db6480a71 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -59,6 +59,7 @@
     TaskInstance,
     TaskInstance as TI,
     TaskInstanceNote,
+    _make_task_carrier,
     clear_task_instances,
     find_relevant_relatives,
 )
@@ -3298,3 +3299,109 @@ def test_get_dagrun_loaded_but_none_returns_dagrun(dag_maker, session):
 
     assert dr_from_ti is not None
     assert dr_from_ti == dr
+
+
+class TestMakeTaskCarrier:
+    """Tests for the _make_task_carrier helper."""
+
+    @pytest.fixture(autouse=True)
+    def sdk_tracer_provider(self):
+        from opentelemetry.sdk.trace import TracerProvider
+
+        provider = TracerProvider()
+        real_tracer = provider.get_tracer("airflow.models.taskinstance")
+        with mock.patch("airflow.models.taskinstance.tracer", real_tracer):
+            yield
+
+    def test_make_task_carrier_returns_traceparent(self):
+        from airflow._shared.observability.traces import new_dagrun_trace_carrier
+
+        carrier = _make_task_carrier(new_dagrun_trace_carrier())
+        assert isinstance(carrier, dict)
+        assert "traceparent" in carrier
+
+    def test_make_task_carrier_child_of_parent(self):
+        from opentelemetry import trace as otel_trace
+        from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+
+        from airflow._shared.observability.traces import new_dagrun_trace_carrier
+
+        parent_carrier = new_dagrun_trace_carrier()
+        child_carrier = _make_task_carrier(parent_carrier)
+
+        propagator = TraceContextTextMapPropagator()
+        parent_trace_id = (
+            otel_trace.get_current_span(context=propagator.extract(parent_carrier))
+            .get_span_context()
+            .trace_id
+        )
+        child_trace_id = (
+            otel_trace.get_current_span(context=propagator.extract(child_carrier)).get_span_context().trace_id
+        )
+        assert child_trace_id == parent_trace_id
+        assert child_trace_id != 0
+
+    def test_make_task_carrier_with_none_carrier(self):
+        carrier = _make_task_carrier(None)
+        assert isinstance(carrier, dict)
+        assert "traceparent" in carrier
+
+
+@pytest.mark.db_test
+def test_insert_mapping_includes_context_carrier(dag_maker, session):
+    """insert_mapping should include a context_carrier with a traceparent derived from the dag run."""
+    from opentelemetry.sdk.trace import TracerProvider
+
+    from airflow._shared.observability.traces import new_dagrun_trace_carrier
+
+    provider = TracerProvider()
+    real_tracer = provider.get_tracer("airflow.models.taskinstance")
+    with mock.patch("airflow.models.taskinstance.tracer", real_tracer):
+        with dag_maker("test_insert_mapping_carrier"):
+            EmptyOperator(task_id="t1")
+        session.flush()
+
+        # Get the scheduler-side operator (has a proper PriorityWeightStrategy, not the enum weight_rule).
+        op = create_scheduler_operator(dag_maker.dag.get_task("t1"))
+
+        # Mock the DagRun to avoid inserting into the dag_run table (schema migrations may be pending).
+        dag_run = mock.MagicMock()
+        dag_run.context_carrier = new_dagrun_trace_carrier()
+
+        mapping = TaskInstance.insert_mapping(
+            run_id="test_run",
+            task=op,
+            map_index=0,
+            dag_version_id=None,
+            dag_run=dag_run,
+        )
+
+        assert "context_carrier" in mapping
+        assert mapping["context_carrier"] is not None
+        assert "traceparent" in mapping["context_carrier"]
+
+
+@pytest.mark.db_test
+def test_clear_task_instances_resets_context_carrier(dag_maker, session):
+    """clear_task_instances should assign fresh context carriers to both the TI and its dag run."""
+    from opentelemetry.sdk.trace import TracerProvider
+
+    provider = TracerProvider()
+    real_tracer = provider.get_tracer("airflow.models.taskinstance")
+    with mock.patch("airflow.models.taskinstance.tracer", real_tracer):
+        with dag_maker("test_clear_carrier"):
+            EmptyOperator(task_id="t1")
+        dag_run = dag_maker.create_dagrun()
+        ti = dag_run.get_task_instance("t1", session=session)
+        ti.state = TaskInstanceState.SUCCESS
+        # Set an explicit carrier so we can verify it changes.
+        ti.context_carrier = {"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaa0001-bbbbbbbbbbbbbbbb-01"}
+        session.flush()
+
+        original_ti_traceparent = ti.context_carrier["traceparent"]
+        original_dr_traceparent = dag_run.context_carrier["traceparent"]
+
+        clear_task_instances([ti], session)
+
+        assert ti.context_carrier["traceparent"] != original_ti_traceparent
+        assert dag_run.context_carrier["traceparent"] != original_dr_traceparent
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 611c4fc28ec19..6782a8b992880 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -146,7 +146,7 @@ def _make_task_span(msg: StartupDetails):
         TraceContextTextMapPropagator().extract(msg.ti.context_carrier) if msg.ti.context_carrier else None
     )
     ti = msg.ti
-    span_name = f"task_run.{ti.task_id}"
+    span_name = f"worker.{ti.task_id}"
    if ti.map_index is not None and ti.map_index >= 0:
        span_name += f"_{ti.map_index}"
    with tracer.start_as_current_span(span_name, context=parent_context) as span:
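
Note on the mechanism used throughout this patch: the context_carrier values stored on DagRun and TaskInstance are plain W3C trace-context mappings (a "traceparent" entry), and each component re-links its spans by running the propagator's inject/extract round trip. The following standalone sketch only illustrates that round trip; it assumes the opentelemetry-sdk package, and the span names are placeholders rather than the exact names emitted by the patch.

    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    provider = TracerProvider()
    tracer = provider.get_tracer("sketch")

    # Producer side (e.g. the scheduler creating a dag run): start a span and
    # serialise its context into a carrier dict that can be stored in the database.
    parent_span = tracer.start_span("dag_run.example")
    carrier: dict[str, str] = {}
    TraceContextTextMapPropagator().inject(carrier, context=trace.set_span_in_context(parent_span))
    # carrier now looks like {"traceparent": "00-<trace_id>-<span_id>-01"}

    # Consumer side (e.g. the triggerer or worker): rebuild the context from the
    # stored carrier and start a child span parented under it.
    parent_ctx = TraceContextTextMapPropagator().extract(carrier)
    child_span = tracer.start_span("task_run.example", context=parent_ctx)
    child_span.end()
    parent_span.end()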