-
Notifications
You must be signed in to change notification settings - Fork 16.8k
Re-enable start_from_trigger feature with rendering of template fields #55068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
90fa8fd
5f7306d
2fddf07
a21201f
937a379
6819e63
f5dd331
ed0594b
ec34481
df75780
23d7aea
6684575
6c28707
96d1893
4ab221e
1da2ca1
0b4a36a
67bfa1e
bb80bfc
558c88a
d5d339f
c7ce525
f631b37
499f463
259ae7b
a7e7c69
a019832
b18e942
2dc4f4b
337b22f
c2de568
02e3413
064e59e
9010f1c
a3a9964
af9553d
40b9b2d
5042c3c
3e92f47
f3ca48a
85ac14d
3041c11
56f8cec
cc9b2cc
f3ffb51
c2c98b7
1a85868
76ee5a6
2834afc
d19f690
399c94b
6790e65
5aace8f
5331992
39f0782
75dfc99
1436811
97b4d45
bac2d24
4796335
8e38dcf
19f975b
869ba37
1d0c6fd
2684561
c2f8271
8717575
f7fa0b0
cd23b0c
66b94f5
f5e3289
2bfafa1
0f564af
e30f00e
e22575a
f2decd4
eff4fbd
129e390
8eb0042
7e5c062
20f7ac0
856c496
398265e
e2ebeeb
169e5d9
1b856ae
55ed71d
68a3d70
937f808
0ad75bf
3d7c35e
fbaec4e
ade3a17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,7 @@ | |
| import sys | ||
| import time | ||
| from collections import deque | ||
| from collections.abc import Generator, Iterable | ||
| from collections.abc import Callable, Generator, Iterable | ||
| from contextlib import suppress | ||
| from datetime import datetime | ||
| from socket import socket | ||
|
|
@@ -51,6 +51,7 @@ | |
| from airflow.executors.workloads.task import TaskInstanceDTO | ||
| from airflow.jobs.base_job_runner import BaseJobRunner | ||
| from airflow.jobs.job import perform_heartbeat | ||
| from airflow.models.dagbag import DBDagBag | ||
| from airflow.models.trigger import Trigger | ||
| from airflow.observability.metrics import stats_utils | ||
| from airflow.sdk.api.datamodels._generated import HITLDetailResponse | ||
|
|
@@ -84,10 +85,12 @@ | |
| _RequestFrame, | ||
| ) | ||
| from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader | ||
| from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance | ||
| from airflow.serialization.serialized_objects import DagSerialization | ||
| from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent | ||
| from airflow.utils.helpers import log_filename_template_renderer | ||
| from airflow.utils.log.logging_mixin import LoggingMixin | ||
| from airflow.utils.session import provide_session | ||
| from airflow.utils.session import create_session, provide_session | ||
|
|
||
| if TYPE_CHECKING: | ||
| from opentelemetry.util._decorator import _AgnosticContextManager | ||
|
|
@@ -97,6 +100,7 @@ | |
| from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI | ||
| from airflow.jobs.job import Job | ||
| from airflow.sdk.api.client import Client | ||
| from airflow.sdk.definitions.context import Context | ||
| from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -658,6 +662,65 @@ def emit_metrics(self): | |
| extra_tags={"hostname": self.job.hostname}, | ||
| ) | ||
|
|
||
| def _create_workload( | ||
| self, | ||
| trigger: Trigger, | ||
| dag_bag: DBDagBag, | ||
| render_log_fname: Callable[..., str], | ||
| session: Session, | ||
| ) -> workloads.RunTrigger | None: | ||
| if trigger.task_instance is None: | ||
| return workloads.RunTrigger( | ||
| id=trigger.id, | ||
| classpath=trigger.classpath, | ||
| encrypted_kwargs=trigger.encrypted_kwargs, | ||
| ) | ||
|
|
||
| if not trigger.task_instance.dag_version_id: | ||
| # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none | ||
| log.warning( | ||
| "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger", | ||
| ti_id=trigger.task_instance.id, | ||
| ) | ||
| return None | ||
|
|
||
| log_path = render_log_fname(ti=trigger.task_instance) | ||
| ser_ti = TaskInstanceDTO.model_validate(trigger.task_instance, from_attributes=True) | ||
|
|
||
| # When producing logs from TIs, include the job id producing the logs to disambiguate it. | ||
| self.logger_cache[trigger.id] = TriggerLoggingFactory( | ||
| log_path=f"{log_path}.trigger.{self.job.id}.log", | ||
| ti=ser_ti, # type: ignore | ||
| ) | ||
|
|
||
| serialized_dag_model = dag_bag.get_serialized_dag_model( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Consider adding a lightweight indicator (e.g., a boolean flag on the Trigger model or TI) so you can skip the DAG load entirely when
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That will be done in yet another PR. |
||
| version_id=trigger.task_instance.dag_version_id, | ||
| session=session, | ||
| ) | ||
|
|
||
| if serialized_dag_model: | ||
| task = serialized_dag_model.dag.get_task(trigger.task_instance.task_id) | ||
|
|
||
| # When a TaskInstance of a Trigger contains a task with start_from_trigger enabled, | ||
| # it means we need to load the SerializedDagModel so we can build a RuntimeTaskInstance later on which | ||
| # will allow us to build a context on which we will render the templated fields. | ||
| if task.start_from_trigger: | ||
| return workloads.RunTrigger( | ||
| id=trigger.id, | ||
| classpath=trigger.classpath, | ||
| encrypted_kwargs=trigger.encrypted_kwargs, | ||
| ti=ser_ti, | ||
| timeout_after=trigger.task_instance.trigger_timeout, | ||
| dag_data=serialized_dag_model.data, | ||
| ) | ||
| return workloads.RunTrigger( | ||
| id=trigger.id, | ||
| classpath=trigger.classpath, | ||
| encrypted_kwargs=trigger.encrypted_kwargs, | ||
| ti=ser_ti, | ||
| timeout_after=trigger.task_instance.trigger_timeout, | ||
| ) | ||
|
|
||
| def update_triggers(self, requested_trigger_ids: set[int]): | ||
| """ | ||
| Request that we update what triggers we're running. | ||
|
|
@@ -666,8 +729,8 @@ def update_triggers(self, requested_trigger_ids: set[int]): | |
| adds them to the dequeues so the subprocess can actually mutate the running | ||
| trigger set. | ||
| """ | ||
| dag_bag = DBDagBag() | ||
| render_log_fname = log_filename_template_renderer() | ||
|
|
||
| known_trigger_ids = ( | ||
| self.running_triggers.union(x[0] for x in self.events) | ||
| .union(self.cancelling_triggers) | ||
|
|
@@ -678,60 +741,48 @@ def update_triggers(self, requested_trigger_ids: set[int]): | |
| new_trigger_ids = requested_trigger_ids - known_trigger_ids | ||
| cancel_trigger_ids = self.running_triggers - requested_trigger_ids | ||
| # Bulk-fetch new trigger records | ||
| new_triggers = Trigger.bulk_fetch(new_trigger_ids) | ||
| trigger_ids_with_non_task_associations = Trigger.fetch_trigger_ids_with_non_task_associations() | ||
| to_create: list[workloads.RunTrigger] = [] | ||
| # Add in new triggers | ||
| for new_id in new_trigger_ids: | ||
| # Check it didn't vanish in the meantime | ||
| if new_id not in new_triggers: | ||
| log.warning("Trigger disappeared before we could start it", id=new_id) | ||
| continue | ||
|
|
||
| new_trigger_orm = new_triggers[new_id] | ||
|
|
||
| # If the trigger is not associated to a task, an asset, or a callback, this means the TaskInstance | ||
| # row was updated by either Trigger.submit_event or Trigger.submit_failure | ||
| # and can happen when a single trigger Job is being run on multiple TriggerRunners | ||
| # in a High-Availability setup. | ||
| if new_trigger_orm.task_instance is None and new_id not in trigger_ids_with_non_task_associations: | ||
| log.info( | ||
| ( | ||
| "TaskInstance Trigger is None. It was likely updated by another trigger job. " | ||
| "Skipping trigger instantiation." | ||
| ), | ||
| id=new_id, | ||
| ) | ||
| continue | ||
|
|
||
| workload = workloads.RunTrigger( | ||
| classpath=new_trigger_orm.classpath, | ||
| id=new_id, | ||
| encrypted_kwargs=new_trigger_orm.encrypted_kwargs, | ||
| ti=None, | ||
| with create_session() as session: | ||
| # Bulk-fetch new trigger records | ||
| new_triggers = Trigger.bulk_fetch(new_trigger_ids, session=session) | ||
| trigger_ids_with_non_task_associations = Trigger.fetch_trigger_ids_with_non_task_associations( | ||
| session=session | ||
| ) | ||
| if new_trigger_orm.task_instance: | ||
| log_path = render_log_fname(ti=new_trigger_orm.task_instance) | ||
| if not new_trigger_orm.task_instance.dag_version_id: | ||
| # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none | ||
| log.warning( | ||
| "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger", | ||
| ti_id=new_trigger_orm.task_instance.id, | ||
| ) | ||
| to_create: list[workloads.RunTrigger] = [] | ||
| # Add in new triggers | ||
| for new_trigger_id in new_trigger_ids: | ||
| # Check it didn't vanish in the meantime | ||
| if new_trigger_id not in new_triggers: | ||
| log.warning("Trigger disappeared before we could start it", id=new_trigger_id) | ||
| continue | ||
| ser_ti = TaskInstanceDTO.model_validate(new_trigger_orm.task_instance, from_attributes=True) | ||
| # When producing logs from TIs, include the job id producing the logs to disambiguate it. | ||
| self.logger_cache[new_id] = TriggerLoggingFactory( | ||
| log_path=f"{log_path}.trigger.{self.job.id}.log", | ||
| ti=ser_ti, # type: ignore | ||
| ) | ||
|
|
||
| workload.ti = ser_ti | ||
| workload.timeout_after = new_trigger_orm.task_instance.trigger_timeout | ||
| new_trigger_orm = new_triggers[new_trigger_id] | ||
|
|
||
| # If the trigger is not associated to a task, an asset, or a callback, this means the TaskInstance | ||
| # row was updated by either Trigger.submit_event or Trigger.submit_failure | ||
| # and can happen when a single trigger Job is being run on multiple TriggerRunners | ||
| # in a High-Availability setup. | ||
| if ( | ||
| new_trigger_orm.task_instance is None | ||
| and new_trigger_id not in trigger_ids_with_non_task_associations | ||
| ): | ||
| log.info( | ||
| ( | ||
| "TaskInstance of Trigger is None. It was likely updated by another trigger job. " | ||
| "Skipping trigger instantiation." | ||
| ), | ||
| id=new_trigger_id, | ||
| ) | ||
| continue | ||
|
|
||
| to_create.append(workload) | ||
| if workload := self._create_workload( | ||
| trigger=new_trigger_orm, | ||
| dag_bag=dag_bag, | ||
| render_log_fname=render_log_fname, | ||
| session=session, | ||
| ): | ||
| to_create.append(workload) | ||
|
|
||
| self.creating_triggers.extend(to_create) | ||
| self.creating_triggers.extend(to_create) | ||
|
|
||
| if cancel_trigger_ids: | ||
| # Enqueue orphaned triggers for cancellation | ||
|
|
@@ -986,9 +1037,19 @@ async def init_comms(self): | |
| raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}") | ||
|
|
||
| async def create_triggers(self): | ||
| def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance: | ||
dabla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| task = DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id) | ||
|
|
||
| # I need to recreate a TaskInstance from task_runner before invoking get_template_context (airflow.executors.workloads.TaskInstance) | ||
| return RuntimeTaskInstance.model_construct( | ||
dabla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| **workload.ti.model_dump(exclude_unset=True), | ||
| task=task, | ||
| ) | ||
dabla marked this conversation as resolved.
Show resolved
Hide resolved
dabla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| """Drain the to_create queue and create all new triggers that have been requested in the DB.""" | ||
dabla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| while self.to_create: | ||
| await asyncio.sleep(0) | ||
| context: Context | None = None | ||
| workload = self.to_create.popleft() | ||
| trigger_id = workload.id | ||
| if trigger_id in self.triggers: | ||
|
|
@@ -1016,24 +1077,32 @@ async def create_triggers(self): | |
| # that could cause None values in collections. | ||
| kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs) | ||
| deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for k, v in kw.items()} | ||
| trigger_instance = trigger_class(**deserialised_kwargs) | ||
|
|
||
| if ti := workload.ti: | ||
| trigger_name = f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})" | ||
| trigger_instance = trigger_class(**deserialised_kwargs) | ||
|
|
||
| if workload.dag_data: | ||
| runtime_ti = create_runtime_ti(workload.dag_data) | ||
| context = runtime_ti.get_template_context() | ||
| trigger_instance.task_instance = runtime_ti | ||
| else: | ||
| trigger_instance.task_instance = ti | ||
| else: | ||
| trigger_name = f"ID {trigger_id}" | ||
| trigger_instance = trigger_class(**deserialised_kwargs) | ||
| except TypeError as err: | ||
| self.log.error("Trigger failed to inflate", error=err) | ||
| self.failed_triggers.append((trigger_id, err)) | ||
| continue | ||
| trigger_instance.trigger_id = trigger_id | ||
| trigger_instance.triggerer_job_id = self.job_id | ||
| trigger_instance.task_instance = ti = workload.ti | ||
| trigger_instance.timeout_after = workload.timeout_after | ||
|
|
||
| trigger_name = ( | ||
| f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})" | ||
| if ti | ||
| else f"ID {trigger_id}" | ||
| ) | ||
| self.triggers[trigger_id] = { | ||
| "task": asyncio.create_task( | ||
| self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name | ||
| self.run_trigger(trigger_id, trigger_instance, workload.timeout_after, context), | ||
| name=trigger_name, | ||
| ), | ||
| "is_watcher": isinstance(trigger_instance, BaseEventTrigger), | ||
| "name": trigger_name, | ||
|
|
@@ -1200,7 +1269,13 @@ async def block_watchdog(self): | |
| ) | ||
| Stats.incr("triggers.blocked_main_thread") | ||
|
|
||
| async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after: datetime | None = None): | ||
| async def run_trigger( | ||
| self, | ||
| trigger_id: int, | ||
| trigger: BaseTrigger, | ||
| timeout_after: datetime | None = None, | ||
| context: Context | None = None, | ||
| ): | ||
| """Run a trigger (they are async generators) and push their events into our outbound event deque.""" | ||
| if not os.environ.get("AIRFLOW_DISABLE_GREENBACK_PORTAL", "").lower() == "true": | ||
| import greenback | ||
|
|
@@ -1213,6 +1288,9 @@ async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after | |
| self.log.info("trigger %s starting", name) | ||
| with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span: | ||
| try: | ||
| if context is not None: | ||
| trigger.render_template_fields(context=context) | ||
|
|
||
| async for event in trigger.run(): | ||
| await self.log.ainfo( | ||
| "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.