Skip to content

Commit f72d2aa

Browse files
committed
Fix AwaitTrigger and Airflow e2e tests
1 parent 96e937d commit f72d2aa

4 files changed

Lines changed: 84 additions & 42 deletions

File tree

airflow-e2e-tests/tests/airflow_e2e_tests/conftest.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,7 @@ def _setup_opensearch_integration(dot_env_file, tmp_dir):
123123

124124

125125
def _copy_kafka_files(tmp_dir):
126-
"""Copy Kafka compose file and init script into the temp directory."""
126+
"""Copy Kafka compose file, init script, and provider source into the temp directory."""
127127
copyfile(KAFKA_DIR_PATH.parent / "kafka.yml", tmp_dir / "kafka.yml")
128128

129129
kafka_dir = tmp_dir / "kafka"
@@ -149,12 +149,7 @@ def _setup_event_driven_integration(dot_env_file, tmp_dir):
149149
}
150150
)
151151

152-
dot_env_file.write_text(
153-
f"AIRFLOW_UID={os.getuid()}\n"
154-
f"AIRFLOW_CONN_KAFKA_DEFAULT='{kafka_conn}'\n"
155-
"_PIP_ADDITIONAL_REQUIREMENTS="
156-
"apache-airflow-providers-apache-kafka apache-airflow-providers-common-messaging\n"
157-
)
152+
dot_env_file.write_text(f"AIRFLOW_UID={os.getuid()}\nAIRFLOW_CONN_KAFKA_DEFAULT='{kafka_conn}'\n")
158153
os.environ["ENV_FILE_PATH"] = str(dot_env_file)
159154

160155

airflow-e2e-tests/tests/airflow_e2e_tests/dags/example_event_driven.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919

2020
import json
2121
from datetime import timedelta
22-
from typing import TYPE_CHECKING
22+
from typing import TYPE_CHECKING, cast
2323

2424
import pendulum
2525

@@ -133,7 +133,7 @@ def process_message(**context) -> bool:
133133
for event in triggering_asset_events[kafka_cdc_asset]:
134134
# Get the message from the TriggerEvent payload
135135
print(f"Asset event: {event}")
136-
process_one_message(event.extra["payload"])
136+
process_one_message(cast("str", event.extra["payload"]))
137137
return True
138138

139139
@task.short_circuit(trigger_rule="all_done")

airflow-e2e-tests/tests/airflow_e2e_tests/event_driven_tests/test_event_driven.py

Lines changed: 79 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -49,17 +49,11 @@ class TestEventDrivenDag:
4949

5050
airflow_client = AirflowClient()
5151

52-
# ------------------------------------------------------------------
53-
# Helpers
54-
# ------------------------------------------------------------------
55-
56-
def _wait_for_kafka_consumer_group(
57-
self, compose_instance, group_id: str, timeout: int = 60, check_interval: int = 3
58-
):
59-
"""Poll until the Kafka consumer group is registered, indicating the trigger is active."""
52+
def _wait_for_kafka_consumer_group(self, compose_instance, timeout: int = 60, check_interval: int = 3):
53+
"""Poll until any Kafka consumer group is registered, indicating the trigger is active."""
6054
start = time.monotonic()
6155
while time.monotonic() - start < timeout:
62-
stdout, _ = compose_instance.exec_in_container(
56+
stdout, _, _ = compose_instance.exec_in_container(
6357
command=[
6458
"kafka-consumer-groups",
6559
"--bootstrap-server",
@@ -69,10 +63,12 @@ def _wait_for_kafka_consumer_group(
6963
service_name="broker",
7064
)
7165
output = stdout.decode() if isinstance(stdout, bytes) else stdout
72-
if group_id in output:
66+
# Any non-empty group listing means the trigger's consumer has registered
67+
groups = [line.strip() for line in output.strip().splitlines() if line.strip()]
68+
if groups:
7369
return
7470
time.sleep(check_interval)
75-
raise TimeoutError(f"Kafka consumer group '{group_id}' not registered within {timeout}s")
71+
raise TimeoutError(f"No Kafka consumer group registered within {timeout}s")
7672

7773
def _wait_for_consumer_dag_runs(
7874
self, expected_count: int, timeout: int = 600, check_interval: int = 10
@@ -110,7 +106,7 @@ def _wait_for_consumer_dag_runs(
110106

111107
def _get_topic_offset(self, compose_instance, topic: str) -> int:
112108
"""Return the current end-offset of *topic* via ``kafka-get-offsets`` inside the broker."""
113-
stdout, _ = compose_instance.exec_in_container(
109+
stdout, _, _ = compose_instance.exec_in_container(
114110
command=[
115111
"kafka-get-offsets",
116112
"--bootstrap-server",
@@ -123,9 +119,21 @@ def _get_topic_offset(self, compose_instance, topic: str) -> int:
123119
output = stdout.decode() if isinstance(stdout, bytes) else stdout
124120
return _parse_topic_offset(output, topic)
125121

126-
# ------------------------------------------------------------------
127-
# Test
128-
# ------------------------------------------------------------------
122+
def _wait_for_topic_offset(
123+
self, compose_instance, topic: str, expected_offset: int, timeout: int = 30, check_interval: int = 2
124+
) -> int:
125+
"""Poll until *topic* reaches *expected_offset*, then return the offset."""
126+
start = time.monotonic()
127+
offset = self._get_topic_offset(compose_instance, topic)
128+
while offset < expected_offset and time.monotonic() - start < timeout:
129+
time.sleep(check_interval)
130+
offset = self._get_topic_offset(compose_instance, topic)
131+
return offset
132+
133+
def _get_task_states(self, run_id: str) -> dict[str, str]:
134+
"""Return a mapping of task_id -> state for a consumer DAG run."""
135+
response = self.airflow_client.get_task_instances(CONSUMER_DAG_ID, run_id)
136+
return {ti["task_id"]: ti["state"] for ti in response["task_instances"]}
129137

130138
def test_producer_triggers_consumer_and_kafka_offsets(self, compose_instance):
131139
"""Trigger the producer once and verify 9 consumer runs and Kafka offsets.
@@ -135,16 +143,19 @@ def test_producer_triggers_consumer_and_kafka_offsets(self, compose_instance):
135143
2. Wait for the Kafka MessageQueueTrigger to begin polling.
136144
3. Trigger the producer DAG and wait for it to succeed.
137145
4. Wait for 9 consumer DAG runs to reach a terminal state.
138-
5. Verify that the ``fizz_buzz`` topic has offset 9 (all messages produced).
139-
6. Verify that the ``dlq`` topic has offset 1 (the malformed message).
146+
5. All 9 DAG runs succeed. Verify task-level behavior:
147+
- 1 run has a failed ``process_message`` task and executes ``handle_dlq``.
148+
- 8 runs succeed on ``process_message`` and skip ``handle_dlq``.
149+
6. Verify that the ``fizz_buzz`` topic has offset 9 (all messages produced).
150+
7. Verify that the ``dlq`` topic has offset 1 (the malformed message).
140151
"""
141152
# 1. Unpause consumer so the triggerer registers the AssetWatcher
142153
self.airflow_client.un_pause_dag(CONSUMER_DAG_ID)
143154

144155
# 2. Wait for the triggerer to start the MessageQueueTrigger and subscribe.
145156
# The trigger uses poll_interval=1 and auto.offset.reset=latest so it
146157
# must be actively polling before the producer writes.
147-
self._wait_for_kafka_consumer_group(compose_instance, "kafka_default_group")
158+
self._wait_for_kafka_consumer_group(compose_instance)
148159

149160
# 3. Trigger producer and wait for it to complete
150161
producer_state = self.airflow_client.trigger_dag_and_wait(PRODUCER_DAG_ID)
@@ -156,28 +167,63 @@ def test_producer_triggers_consumer_and_kafka_offsets(self, compose_instance):
156167
f"Expected {EXPECTED_CONSUMER_RUNS} consumer runs, got {len(consumer_runs)}"
157168
)
158169

159-
# 5. Verify consumer run outcomes:
160-
# - 8 runs process valid orders and succeed
161-
# - 1 run hits the malformed message, process_message fails after retries,
162-
# then handle_dlq sends it to the DLQ; the overall run is still "failed"
163-
success_runs = [r for r in consumer_runs if r["state"] == "success"]
164-
failed_runs = [r for r in consumer_runs if r["state"] == "failed"]
165-
assert len(success_runs) == 8, (
166-
f"Expected 8 successful consumer runs, got {len(success_runs)}. "
167-
f"States: {[(r['dag_run_id'], r['state']) for r in consumer_runs]}"
170+
# 5. All 9 DAG runs should succeed
171+
for run in consumer_runs:
172+
assert run["state"] == "success", (
173+
f"Expected all consumer runs to succeed, but run {run['dag_run_id']} "
174+
f"has state '{run['state']}'"
175+
)
176+
177+
# 6. Verify task-level behavior per run:
178+
# - 1 run: process_message fails (malformed message), handle_dlq executes
179+
# - 8 runs: process_message succeeds, handle_dlq is skipped
180+
dlq_runs = []
181+
normal_runs = []
182+
for run in consumer_runs:
183+
ti_states = self._get_task_states(run["dag_run_id"])
184+
if ti_states.get("process_message") == "failed":
185+
dlq_runs.append(run["dag_run_id"])
186+
assert ti_states.get("should_handle_dlq") == "success", (
187+
f"Run {run['dag_run_id']}: expected should_handle_dlq=success, "
188+
f"got '{ti_states.get('should_handle_dlq')}'"
189+
)
190+
assert ti_states.get("handle_dlq") == "success", (
191+
f"Run {run['dag_run_id']}: expected handle_dlq=success, "
192+
f"got '{ti_states.get('handle_dlq')}'"
193+
)
194+
else:
195+
normal_runs.append(run["dag_run_id"])
196+
assert ti_states.get("process_message") == "success", (
197+
f"Run {run['dag_run_id']}: expected process_message=success, "
198+
f"got '{ti_states.get('process_message')}'"
199+
)
200+
assert ti_states.get("should_handle_dlq") == "success", (
201+
f"Run {run['dag_run_id']}: expected should_handle_dlq=success, "
202+
f"got '{ti_states.get('should_handle_dlq')}'"
203+
)
204+
assert ti_states.get("handle_dlq") == "skipped", (
205+
f"Run {run['dag_run_id']}: expected handle_dlq=skipped, "
206+
f"got '{ti_states.get('handle_dlq')}'"
207+
)
208+
209+
assert len(dlq_runs) == 1, (
210+
f"Expected exactly 1 run with failed process_message (DLQ path), got {len(dlq_runs)}: {dlq_runs}"
168211
)
169-
assert len(failed_runs) == 1, (
170-
f"Expected 1 failed consumer run (malformed message), got {len(failed_runs)}. "
171-
f"States: {[(r['dag_run_id'], r['state']) for r in consumer_runs]}"
212+
assert len(normal_runs) == 8, (
213+
f"Expected 8 runs with successful process_message, got {len(normal_runs)}: {normal_runs}"
172214
)
173215

174-
# 6. Verify Kafka topic offsets
175-
fizz_buzz_offset = self._get_topic_offset(compose_instance, "fizz_buzz")
216+
# 7. Verify Kafka topic offsets
217+
# The DLQ message is produced by the last consumer run to complete, so
218+
# kafka-get-offsets may briefly report a stale offset; poll with a short timeout.
219+
fizz_buzz_offset = self._wait_for_topic_offset(
220+
compose_instance, "fizz_buzz", EXPECTED_FIZZ_BUZZ_OFFSET
221+
)
176222
assert fizz_buzz_offset == EXPECTED_FIZZ_BUZZ_OFFSET, (
177223
f"Expected fizz_buzz offset {EXPECTED_FIZZ_BUZZ_OFFSET}, got {fizz_buzz_offset}"
178224
)
179225

180-
dlq_offset = self._get_topic_offset(compose_instance, "dlq")
226+
dlq_offset = self._wait_for_topic_offset(compose_instance, "dlq", EXPECTED_DLQ_OFFSET)
181227
assert dlq_offset == EXPECTED_DLQ_OFFSET, (
182228
f"Expected dlq offset {EXPECTED_DLQ_OFFSET}, got {dlq_offset}"
183229
)

providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@ def __init__(
8282
poll_interval: float = 5,
8383
commit_offset: bool = True,
8484
) -> None:
85+
super().__init__()
8586
self.topics = topics
8687
self.apply_function = apply_function
8788
self.apply_function_args = apply_function_args or ()

0 commit comments

Comments (0)