@@ -429,7 +429,7 @@ class Consumer(Service, ConsumerT):
429429 _commit_every : Optional [int ]
430430 _n_acked : int = 0
431431
432- _active_partitions : Optional [ Set [TP ] ]
432+ _active_partitions : Set [TP ]
433433 _paused_partitions : Set [TP ]
434434 _buffered_partitions : Set [TP ]
435435
@@ -495,7 +495,7 @@ def on_init_dependencies(self) -> Iterable[ServiceT]:
495495 return []
496496
497497 def _reset_state (self ) -> None :
498- self ._active_partitions = None
498+ self ._active_partitions = set ()
499499 self ._paused_partitions = set ()
500500 self ._buffered_partitions = set ()
501501 self .can_resume_flow .clear ()
@@ -516,9 +516,12 @@ def _get_active_partitions(self) -> Set[TP]:
516516 return tps
517517
518518 def _set_active_tps (self , tps : Set [TP ]) -> Set [TP ]:
519- xtps = self ._active_partitions = ensure_TPset (tps ) # copy
520- xtps .difference_update (self ._paused_partitions )
521- return xtps
519+ if self ._active_partitions is None :
520+ self ._active_partitions = set ()
521+ self ._active_partitions .clear ()
522+ self ._active_partitions .update (ensure_TPset (tps ))
523+ self ._active_partitions .difference_update (self ._paused_partitions )
524+ return self ._active_partitions
522525
523526 def on_buffer_full (self , tp : TP ) -> None :
524527 # do not remove the partition when in recovery
@@ -730,6 +733,13 @@ async def getmany(self, timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
730733 # convert timestamp to seconds from int milliseconds.
731734 yield tp , to_message (tp , record )
732735
736+ async def _wait_suspend (self ):
737+ """Wrapper around self.suspend_flow.wait() with no return value.
738+
739+ This allows for easily
740+ """
741+ await self .suspend_flow .wait ()
742+
733743 async def _wait_next_records (
734744 self , timeout : float
735745 ) -> Tuple [Optional [RecordMap ], Optional [Set [TP ]]]:
@@ -750,10 +760,18 @@ async def _wait_next_records(
750760 # Fetch records only if active partitions to avoid the risk of
751761 # fetching all partitions in the beginning when none of the
752762 # partitions is paused/resumed.
753- records = await self ._getmany (
763+ _getmany = self ._getmany (
754764 active_partitions = active_partitions ,
755765 timeout = timeout ,
756766 )
767+ wait_results = await self .wait_first (
768+ _getmany ,
769+ self .suspend_flow .wait (),
770+ )
771+ for coro , result in zip (wait_results .done , wait_results .results ):
772+ if coro is _getmany :
773+ records = result
774+ break
757775 else :
758776 # We should still release to the event loop
759777 await self .sleep (1 )
0 commit comments