Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, Security, status
from opentelemetry import trace
from opentelemetry.trace import StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import JsonValue
from sqlalchemy import and_, func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
Expand All @@ -37,6 +40,7 @@
from sqlalchemy.sql import select
from structlog.contextvars import bind_contextvars

from airflow._shared.observability.traces import override_ids
from airflow._shared.timezones import timezone
from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
Expand Down Expand Up @@ -87,6 +91,7 @@


log = structlog.get_logger(__name__)
tracer = trace.get_tracer(__name__)


@ti_id_router.patch(
Expand Down Expand Up @@ -431,6 +436,45 @@ def ti_update_state(
)


def _emit_task_span(ti, state):
    """Emit a retroactive OTel span for a finished task instance.

    The span is parented to the dag run's propagated trace context, and its
    trace/span ids are forced (via ``override_ids``) to the ids previously
    injected into the task instance's own carrier, so the span lines up with
    the ids other components already logged.
    """
    # Defensive guards: both carriers must be dicts produced by a propagator.
    dag_run = ti.dag_run
    if not dag_run:
        return
    if not isinstance(dag_run.context_carrier, dict):
        return
    if not isinstance(ti.context_carrier, dict):
        return

    propagator = TraceContextTextMapPropagator()
    parent_ctx = propagator.extract(dag_run.context_carrier)
    # Recover the span ids that were pre-minted for this TI.
    ti_ctx = propagator.extract(ti.context_carrier)
    ti_span_context = trace.get_current_span(context=ti_ctx).get_span_context()

    span_name = f"task_run.{ti.task_id}"
    if ti.map_index >= 0:
        span_name = f"{span_name}_{ti.map_index}"

    # Earliest known timestamp wins; utcnow() guarantees a non-empty pool.
    candidates = [t for t in (ti.queued_dttm, ti.start_date, timezone.utcnow()) if t]
    start_ns = int(min(candidates).timestamp() * 1e9)

    with override_ids(ti_span_context.trace_id, ti_span_context.span_id):
        span = tracer.start_span(name=span_name, start_time=start_ns, context=parent_ctx)
        span.set_attributes(
            {
                "airflow.dag_id": ti.dag_id,
                "airflow.task_id": ti.task_id,
                "airflow.dag_run.run_id": ti.run_id,
                "airflow.task_instance.try_number": ti.try_number,
                "airflow.task_instance.map_index": ti.map_index if ti.map_index is not None else -1,
                "airflow.task_instance.state": state,
            }
        )
        # Only SUCCESS maps to OK; every other terminal state is reported as ERROR.
        span.set_status(StatusCode.OK if state == TaskInstanceState.SUCCESS else StatusCode.ERROR)
        span.end()


def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> None:
dr = ti.dag_run

Expand Down Expand Up @@ -479,6 +523,7 @@ def _create_ti_state_update_query_and_update_state(
ti_patch_payload.outlet_events,
session,
)
_emit_task_span(ti, state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down
1 change: 0 additions & 1 deletion airflow-core/src/airflow/executors/workloads/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def make(
from airflow.utils.helpers import log_filename_template_renderer

ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
ser_ti.context_carrier = ti.dag_run.context_carrier
if not bundle_info:
bundle_info = BundleInfo(
name=ti.dag_model.bundle_name,
Expand Down
85 changes: 61 additions & 24 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import anyio
import attrs
import structlog
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
from structlog.contextvars import bind_contextvars as bind_log_contextvars
Expand Down Expand Up @@ -87,6 +90,7 @@
from airflow.utils.session import provide_session

if TYPE_CHECKING:
from opentelemetry.util._decorator import _AgnosticContextManager
from sqlalchemy.orm import Session
from structlog.typing import FilteringBoundLogger, WrappedLogger

Expand All @@ -96,6 +100,33 @@
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


def _make_trigger_span(
    ti: TaskInstanceDTO | None, trigger_id: int, name: str
) -> _AgnosticContextManager[trace.Span]:
    """Build a context manager yielding an active span for a running trigger.

    When a task instance with a trace-context carrier is available, the span
    is parented to it and enriched with task-level attributes; otherwise a
    standalone span named after the trigger id is produced.
    """
    if ti and ti.context_carrier:
        parent_context = TraceContextTextMapPropagator().extract(ti.context_carrier)
    else:
        parent_context = None

    attributes: dict[str, str | int] = {"airflow.trigger.name": name}
    if ti:
        span_name = f"trigger.{ti.task_id}"
        if ti.map_index >= 0:
            span_name = f"{span_name}_{ti.map_index}"
        attributes.update(
            {
                "airflow.dag_id": ti.dag_id,
                "airflow.task_id": ti.task_id,
                "airflow.dag_run.run_id": ti.run_id,
                "airflow.task_instance.try_number": ti.try_number,
                "airflow.task_instance.map_index": ti.map_index,
            }
        )
    else:
        span_name = f"trigger.{trigger_id}"

    return tracer.start_as_current_span(span_name, attributes=attributes, context=parent_context)


__all__ = [
"TriggerRunner",
Expand Down Expand Up @@ -1179,30 +1210,36 @@ async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after

name = self.triggers[trigger_id]["name"]
self.log.info("trigger %s starting", name)
try:
async for event in trigger.run():
await self.log.ainfo(
"Trigger fired event", name=self.triggers[trigger_id]["name"], result=event
)
self.triggers[trigger_id]["events"] += 1
self.events.append((trigger_id, event))
except asyncio.CancelledError:
# We get cancelled by the scheduler changing the task state. But if we do lets give a nice error
# message about it
if timeout := timeout_after:
timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
if timeout < timezone.utcnow():
await self.log.aerror("Trigger cancelled due to timeout")
raise
finally:
# CancelledError will get injected when we're stopped - which is
# fine, the cleanup process will understand that, but we want to
# allow triggers a chance to cleanup, either in that case or if
# they exit cleanly. Exception from cleanup methods are ignored.
with suppress(Exception):
await trigger.cleanup()

await self.log.ainfo("trigger completed", name=name)
with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span:
try:
async for event in trigger.run():
await self.log.ainfo(
"Trigger fired event", name=self.triggers[trigger_id]["name"], result=event
)
self.triggers[trigger_id]["events"] += 1
self.events.append((trigger_id, event))
span.set_status(Status(StatusCode.OK))
except asyncio.CancelledError as e:
# We get cancelled by the scheduler changing the task state. But if we do lets give a nice error
# message about it
if timeout := timeout_after:
timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
if timeout < timezone.utcnow():
await self.log.aerror("Trigger cancelled due to timeout")
span.set_status(Status(StatusCode.ERROR), description=str(e))
raise
except Exception as e:
span.set_status(Status(StatusCode.ERROR), description=str(e))
raise
finally:
# CancelledError will get injected when we're stopped - which is
# fine, the cleanup process will understand that, but we want to
# allow triggers a chance to cleanup, either in that case or if
# they exit cleanly. Exception from cleanup methods are ignored.
with suppress(Exception):
await trigger.cleanup()

await self.log.ainfo("trigger completed", name=name)

def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
"""
Expand Down
16 changes: 14 additions & 2 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,21 +1019,29 @@ def is_effective_leaf(task):
return leaf_tis

def _emit_dagrun_span(self, state: DagRunState):
# just to be safe
if not isinstance(self.context_carrier, dict):
return

ctx = TraceContextTextMapPropagator().extract(self.context_carrier)
span = trace.get_current_span(context=ctx)
span_context = span.get_span_context()
with override_ids(span_context.trace_id, span_context.span_id):
attributes = {
"airflow.dag_id": str(self.dag_id),
"airflow.dag_run.run_id": self.run_id,
"airflow.dag_run.start_date": self.start_date and str(self.start_date) or None,
"airflow.dag_run.end_date": self.end_date and str(self.end_date) or None,
"airflow.dag_run.queued_at": self.queued_at and str(self.queued_at) or None,
"airflow.dag_run.created_at": self.created_at and str(self.created_at) or None,
}
if self.logical_date:
attributes["airflow.dag_run.logical_date"] = str(self.logical_date)
if self.partition_key:
attributes["airflow.dag_run.partition_key"] = str(self.partition_key)
span = tracer.start_span(
name=f"dag_run.{self.dag_id}",
start_time=int((self.start_date or timezone.utcnow()).timestamp() * 1e9),
start_time=int((self.queued_at or self.start_date or timezone.utcnow()).timestamp() * 1e9),
attributes=attributes,
context=context.Context(),
)
Expand Down Expand Up @@ -1763,7 +1771,11 @@ def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[s
created_counts[task.task_type] += 1
for map_index in indexes:
yield TI.insert_mapping(
self.run_id, task, map_index=map_index, dag_version_id=dag_version_id
self.run_id,
task,
map_index=map_index,
dag_version_id=dag_version_id,
dag_run=self,
)

creator = create_ti_mapping
Expand Down
25 changes: 22 additions & 3 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import attrs
import dill
import uuid6
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from sqlalchemy import (
JSON,
Float,
Expand Down Expand Up @@ -67,6 +69,7 @@
from airflow import settings
from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager
from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.observability.traces import new_dagrun_trace_carrier
from airflow._shared.timezones import timezone
from airflow.assets.manager import asset_manager
from airflow.configuration import conf
Expand Down Expand Up @@ -102,7 +105,7 @@
TR = TaskReschedule

log = logging.getLogger(__name__)

tracer = trace.get_tracer(__name__)

if TYPE_CHECKING:
from datetime import datetime
Expand Down Expand Up @@ -382,7 +385,7 @@ def clear_task_instances(
for instance in tis:
run_ids_by_dag_id[instance.dag_id].add(instance.run_id)

drs = session.scalars(
drs: Iterable[DagRun] = session.scalars(
select(DagRun).where(
or_(
*(
Expand All @@ -397,6 +400,7 @@ def clear_task_instances(
# Always update clear_number and queued_at when clearing tasks, regardless of state
dr.clear_number += 1
dr.queued_at = timezone.utcnow()
dr.context_carrier = new_dagrun_trace_carrier()

_recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session)

Expand Down Expand Up @@ -425,6 +429,8 @@ def clear_task_instances(
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
for ti in tis:
ti.context_carrier = _make_task_carrier(ti.dag_run.context_carrier)
session.flush()


Expand Down Expand Up @@ -486,6 +492,17 @@ def uuid7() -> UUID:
return uuid6.uuid7()


def _make_task_carrier(dag_run_context_carrier):
    """Mint a W3C trace-context carrier for a task instance.

    A placeholder span is started under the dag run's propagated context purely
    to allocate trace/span ids; it is deliberately never ended, since only the
    ids injected into the returned carrier are needed downstream.
    """
    propagator = TraceContextTextMapPropagator()
    parent_context = propagator.extract(dag_run_context_carrier) if dag_run_context_carrier else None
    placeholder = tracer.start_span("notused", context=parent_context)  # intentionally never closed
    carrier: dict[str, str] = {}
    propagator.inject(carrier, context=trace.set_span_in_context(placeholder))
    return carrier


class TaskInstance(Base, LoggingMixin, BaseWorkload):
"""
Task instances store the state of a task instance.
Expand Down Expand Up @@ -679,7 +696,7 @@ def stats_tags(self) -> dict[str, str]:

@staticmethod
def insert_mapping(
run_id: str, task: Operator, map_index: int, dag_version_id: UUID | None
run_id: str, task: Operator, map_index: int, *, dag_version_id: UUID | None, dag_run: DagRun
) -> dict[str, Any]:
"""
Insert mapping.
Expand All @@ -689,6 +706,7 @@ def insert_mapping(
priority_weight = task.weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
)
context_carrier = _make_task_carrier(dag_run.context_carrier)

return {
"dag_id": task.dag_id,
Expand All @@ -710,6 +728,7 @@ def insert_mapping(
"map_index": map_index,
"_task_display_property_value": task.task_display_name,
"dag_version_id": dag_version_id,
"context_carrier": context_carrier,
}

@reconstructor
Expand Down
25 changes: 25 additions & 0 deletions airflow-core/src/airflow/models/taskmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
from collections.abc import Collection, Iterable, Sequence
from typing import TYPE_CHECKING, Any

from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select
from sqlalchemy.orm import Mapped, mapped_column

from airflow.models import DagRun
from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.models.dag_version import DagVersion
from airflow.utils.db import exists_query
Expand All @@ -38,6 +41,7 @@

from airflow.models.taskinstance import TaskInstance
from airflow.serialization.definitions.mappedoperator import Operator
tracer = trace.get_tracer(__name__)


class TaskMapVariant(enum.Enum):
Expand All @@ -52,6 +56,17 @@ class TaskMapVariant(enum.Enum):
LIST = "list"


def _make_task_carrier(dag_run_context_carrier):
    """Mint a W3C trace-context carrier for a task instance.

    Starts a throwaway span under the dag run's propagated context purely to
    allocate trace/span ids, then injects those ids into a fresh carrier dict.
    """
    # NOTE(review): duplicated helper — an identical function exists in
    # airflow.models.taskinstance; consider extracting to a shared module.
    parent_context = (
        TraceContextTextMapPropagator().extract(dag_run_context_carrier) if dag_run_context_carrier else None
    )
    # Only the span's ids matter; ending it would mark it complete, so it is left open.
    span = tracer.start_span("notused", context=parent_context)  # intentionally never closed
    new_ctx = trace.set_span_in_context(span)
    carrier: dict[str, str] = {}
    TraceContextTextMapPropagator().inject(carrier, context=new_ctx)
    return carrier


class TaskMap(TaskInstanceDependencies):
"""
Model to track dynamic task-mapping information.
Expand Down Expand Up @@ -254,6 +269,16 @@ def expand_mapped_task(
task.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
if unmapped_ti:
dr = unmapped_ti.dag_run
else:
dr = session.scalar(
select(DagRun).where(
DagRun.dag_id == task.dag_id,
DagRun.run_id == run_id,
)
)
ti.context_carrier = _make_task_carrier(dr.context_carrier)
ti.refresh_from_task(task) # session.merge() loses task information.
all_expanded_tis.append(ti)

Expand Down
Loading
Loading