@@ -49,17 +49,11 @@ class TestEventDrivenDag:
4949
5050 airflow_client = AirflowClient ()
5151
52- # ------------------------------------------------------------------
53- # Helpers
54- # ------------------------------------------------------------------
55-
56- def _wait_for_kafka_consumer_group (
57- self , compose_instance , group_id : str , timeout : int = 60 , check_interval : int = 3
58- ):
59- """Poll until the Kafka consumer group is registered, indicating the trigger is active."""
52+ def _wait_for_kafka_consumer_group (self , compose_instance , timeout : int = 60 , check_interval : int = 3 ):
53+ """Poll until any Kafka consumer group is registered, indicating the trigger is active."""
6054 start = time .monotonic ()
6155 while time .monotonic () - start < timeout :
62- stdout , _ = compose_instance .exec_in_container (
56+ stdout , _ , _ = compose_instance .exec_in_container (
6357 command = [
6458 "kafka-consumer-groups" ,
6559 "--bootstrap-server" ,
@@ -69,10 +63,12 @@ def _wait_for_kafka_consumer_group(
6963 service_name = "broker" ,
7064 )
7165 output = stdout .decode () if isinstance (stdout , bytes ) else stdout
72- if group_id in output :
66+ # Any non-empty group listing means the trigger's consumer has registered
67+ groups = [line .strip () for line in output .strip ().splitlines () if line .strip ()]
68+ if groups :
7369 return
7470 time .sleep (check_interval )
75- raise TimeoutError (f"Kafka consumer group ' { group_id } ' not registered within { timeout } s" )
71+ raise TimeoutError (f"No Kafka consumer group registered within { timeout } s" )
7672
7773 def _wait_for_consumer_dag_runs (
7874 self , expected_count : int , timeout : int = 600 , check_interval : int = 10
@@ -110,7 +106,7 @@ def _wait_for_consumer_dag_runs(
110106
111107 def _get_topic_offset (self , compose_instance , topic : str ) -> int :
112108 """Return the current end-offset of *topic* via ``kafka-get-offsets`` inside the broker."""
113- stdout , _ = compose_instance .exec_in_container (
109+ stdout , _ , _ = compose_instance .exec_in_container (
114110 command = [
115111 "kafka-get-offsets" ,
116112 "--bootstrap-server" ,
@@ -123,9 +119,21 @@ def _get_topic_offset(self, compose_instance, topic: str) -> int:
123119 output = stdout .decode () if isinstance (stdout , bytes ) else stdout
124120 return _parse_topic_offset (output , topic )
125121
126- # ------------------------------------------------------------------
127- # Test
128- # ------------------------------------------------------------------
122+ def _wait_for_topic_offset (
123+ self , compose_instance , topic : str , expected_offset : int , timeout : int = 30 , check_interval : int = 2
124+ ) -> int :
125+ """Poll until *topic* reaches *expected_offset*, then return the offset."""
126+ start = time .monotonic ()
127+ offset = self ._get_topic_offset (compose_instance , topic )
128+ while offset < expected_offset and time .monotonic () - start < timeout :
129+ time .sleep (check_interval )
130+ offset = self ._get_topic_offset (compose_instance , topic )
131+ return offset
132+
def _get_task_states(self, run_id: str) -> dict[str, str]:
    """Return a mapping of task_id -> state for a consumer DAG run."""
    payload = self.airflow_client.get_task_instances(CONSUMER_DAG_ID, run_id)
    states: dict[str, str] = {}
    for instance in payload["task_instances"]:
        states[instance["task_id"]] = instance["state"]
    return states
129137
130138 def test_producer_triggers_consumer_and_kafka_offsets (self , compose_instance ):
131139 """Trigger the producer once and verify 9 consumer runs and Kafka offsets.
@@ -135,16 +143,19 @@ def test_producer_triggers_consumer_and_kafka_offsets(self, compose_instance):
135143 2. Wait for the Kafka MessageQueueTrigger to begin polling.
136144 3. Trigger the producer DAG and wait for it to succeed.
137145 4. Wait for 9 consumer DAG runs to reach a terminal state.
138- 5. Verify that the ``fizz_buzz`` topic has offset 9 (all messages produced).
139- 6. Verify that the ``dlq`` topic has offset 1 (the malformed message).
146+ 5. All 9 DAG runs succeed. Verify task-level behavior:
147+ - 1 run has a failed ``process_message`` task and executes ``handle_dlq``.
148+ - 8 runs succeed on ``process_message`` and skip ``handle_dlq``.
149+ 6. Verify that the ``fizz_buzz`` topic has offset 9 (all messages produced).
150+ 7. Verify that the ``dlq`` topic has offset 1 (the malformed message).
140151 """
141152 # 1. Unpause consumer so the triggerer registers the AssetWatcher
142153 self .airflow_client .un_pause_dag (CONSUMER_DAG_ID )
143154
144155 # 2. Wait for the triggerer to start the MessageQueueTrigger and subscribe.
145156 # The trigger uses poll_interval=1 and auto.offset.reset=latest so it
146157 # must be actively polling before the producer writes.
147- self ._wait_for_kafka_consumer_group (compose_instance , "kafka_default_group" )
158+ self ._wait_for_kafka_consumer_group (compose_instance )
148159
149160 # 3. Trigger producer and wait for it to complete
150161 producer_state = self .airflow_client .trigger_dag_and_wait (PRODUCER_DAG_ID )
@@ -156,28 +167,63 @@ def test_producer_triggers_consumer_and_kafka_offsets(self, compose_instance):
156167 f"Expected { EXPECTED_CONSUMER_RUNS } consumer runs, got { len (consumer_runs )} "
157168 )
158169
159- # 5. Verify consumer run outcomes:
160- # - 8 runs process valid orders and succeed
161- # - 1 run hits the malformed message, process_message fails after retries,
162- # then handle_dlq sends it to the DLQ; the overall run is still "failed"
163- success_runs = [r for r in consumer_runs if r ["state" ] == "success" ]
164- failed_runs = [r for r in consumer_runs if r ["state" ] == "failed" ]
165- assert len (success_runs ) == 8 , (
166- f"Expected 8 successful consumer runs, got { len (success_runs )} . "
167- f"States: { [(r ['dag_run_id' ], r ['state' ]) for r in consumer_runs ]} "
170+ # 5. All 9 DAG runs should succeed
171+ for run in consumer_runs :
172+ assert run ["state" ] == "success" , (
173+ f"Expected all consumer runs to succeed, but run { run ['dag_run_id' ]} "
174+ f"has state '{ run ['state' ]} '"
175+ )
176+
177+ # 6. Verify task-level behavior per run:
178+ # - 1 run: process_message fails (malformed message), handle_dlq executes
179+ # - 8 runs: process_message succeeds, handle_dlq is skipped
180+ dlq_runs = []
181+ normal_runs = []
182+ for run in consumer_runs :
183+ ti_states = self ._get_task_states (run ["dag_run_id" ])
184+ if ti_states .get ("process_message" ) == "failed" :
185+ dlq_runs .append (run ["dag_run_id" ])
186+ assert ti_states .get ("should_handle_dlq" ) == "success" , (
187+ f"Run { run ['dag_run_id' ]} : expected should_handle_dlq=success, "
188+ f"got '{ ti_states .get ('should_handle_dlq' )} '"
189+ )
190+ assert ti_states .get ("handle_dlq" ) == "success" , (
191+ f"Run { run ['dag_run_id' ]} : expected handle_dlq=success, "
192+ f"got '{ ti_states .get ('handle_dlq' )} '"
193+ )
194+ else :
195+ normal_runs .append (run ["dag_run_id" ])
196+ assert ti_states .get ("process_message" ) == "success" , (
197+ f"Run { run ['dag_run_id' ]} : expected process_message=success, "
198+ f"got '{ ti_states .get ('process_message' )} '"
199+ )
200+ assert ti_states .get ("should_handle_dlq" ) == "success" , (
201+ f"Run { run ['dag_run_id' ]} : expected should_handle_dlq=success, "
202+ f"got '{ ti_states .get ('should_handle_dlq' )} '"
203+ )
204+ assert ti_states .get ("handle_dlq" ) == "skipped" , (
205+ f"Run { run ['dag_run_id' ]} : expected handle_dlq=skipped, "
206+ f"got '{ ti_states .get ('handle_dlq' )} '"
207+ )
208+
209+ assert len (dlq_runs ) == 1 , (
210+ f"Expected exactly 1 run with failed process_message (DLQ path), got { len (dlq_runs )} : { dlq_runs } "
168211 )
169- assert len (failed_runs ) == 1 , (
170- f"Expected 1 failed consumer run (malformed message), got { len (failed_runs )} . "
171- f"States: { [(r ['dag_run_id' ], r ['state' ]) for r in consumer_runs ]} "
212+ assert len (normal_runs ) == 8 , (
213+ f"Expected 8 runs with successful process_message, got { len (normal_runs )} : { normal_runs } "
172214 )
173215
174- # 6. Verify Kafka topic offsets
175- fizz_buzz_offset = self ._get_topic_offset (compose_instance , "fizz_buzz" )
216+ # 7. Verify Kafka topic offsets
217+ # The DLQ message is produced by the last consumer run to complete, so
218+ # kafka-get-offsets may briefly report a stale offset; poll with a short timeout.
219+ fizz_buzz_offset = self ._wait_for_topic_offset (
220+ compose_instance , "fizz_buzz" , EXPECTED_FIZZ_BUZZ_OFFSET
221+ )
176222 assert fizz_buzz_offset == EXPECTED_FIZZ_BUZZ_OFFSET , (
177223 f"Expected fizz_buzz offset { EXPECTED_FIZZ_BUZZ_OFFSET } , got { fizz_buzz_offset } "
178224 )
179225
180- dlq_offset = self ._get_topic_offset (compose_instance , "dlq" )
226+ dlq_offset = self ._wait_for_topic_offset (compose_instance , "dlq" , EXPECTED_DLQ_OFFSET )
181227 assert dlq_offset == EXPECTED_DLQ_OFFSET , (
182228 f"Expected dlq offset { EXPECTED_DLQ_OFFSET } , got { dlq_offset } "
183229 )
0 commit comments