diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 5f5073c916b68..644b0f5fa2a6d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -892,15 +892,13 @@ def get_task_instance_count( query = query.where(TI.run_id.in_(run_ids)) if task_group_id: - group_tasks = _get_group_tasks(dag_id, task_group_id, session, dag_bag, logical_dates, run_ids) + group_tasks = _get_group_tasks( + dag_id, task_group_id, session, dag_bag, logical_dates, run_ids, map_index + ) # Get unique (task_id, map_index) pairs - task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks] - if map_index is not None: - task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks if ti.map_index == map_index] - if not task_map_pairs: # If no task group tasks found, default to checking the task group ID itself # This matches the behavior in _get_external_task_group_task_ids @@ -1000,15 +998,18 @@ def get_task_instance_states( if run_ids: query = query.where(TI.run_id.in_(run_ids)) + if map_index is not None: + query = query.where(TI.map_index == map_index) + results = session.scalars(query).all() if task_group_id: - group_tasks = _get_group_tasks(dag_id, task_group_id, session, dag_bag, logical_dates, run_ids) + group_tasks = _get_group_tasks( + dag_id, task_group_id, session, dag_bag, logical_dates, run_ids, map_index + ) results = results + group_tasks if task_ids else group_tasks - if map_index is not None: - results = [task for task in results if task.map_index == map_index] [ run_id_task_state_map[task.run_id].update( {task.task_id: task.state} @@ -1049,7 +1050,13 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: def _get_group_tasks( - dag_id: str, task_group_id: str, session: SessionDep, dag_bag: DagBagDep, logical_dates=None, run_ids=None + dag_id: str, + task_group_id: str, + session: SessionDep, + dag_bag: DagBagDep, + logical_dates=None, + run_ids=None, + map_index: int | None = None, ): # Get all tasks in the task group dag = get_latest_version_of_dag(dag_bag, dag_id, session, include_reason=True) @@ -1070,6 +1077,7 @@ def _get_group_tasks( TI.task_id.in_(task.task_id for task in task_group.iter_tasks()), *([TI.logical_date.in_(logical_dates)] if logical_dates else []), *([TI.run_id.in_(run_ids)] if run_ids else []), + *([TI.map_index == map_index] if map_index is not None else []), ) ).all() diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 7f766ede71e4d..1ab690239fe6c 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -2336,6 +2336,7 @@ def add_one(x): ("map_index", "dynamic_task_args", "task_ids", "task_group_name", "expected_count"), ( pytest.param(None, [1, 2, 3], None, None, 5, id="use-default-map-index-None"), + pytest.param(0, [1, 2, 3], None, None, 1, id="with-map-index-0-no-task-group"), pytest.param(-1, [1, 2, 3], ["task1"], None, 1, id="with-task-ids-and-map-index-(-1)"), pytest.param(None, [1, 2, 3], None, "group1", 4, id="with-task-group-id-and-map-index-None"), pytest.param(0, [1, 2, 3], None, "group1", 1, id="with-task-group-id-and-map-index-0"), @@ -2853,6 +2854,15 @@ def add_one(x): }, id="with-default-map-index-None", ), + pytest.param( + 0, + [1, 2, 3], + None, + None, + {"-1": State.SUCCESS, "0": State.FAILED, "1": State.SUCCESS, "2": State.SUCCESS}, + {"group1.add_one_0": "failed"}, + id="with-map-index-0-no-task-group", + ), pytest.param( -1, [1, 2, 3],