Skip to content

Commit ef3c6e1

Browse files
ekerstensEric Kerstens
andauthored
set active_tps in place (#239)
* set active_tps in place * Cancel getmany during rebalance * Fix nonetype error * Check for nonetype * Fix wait_results coro check * change == to is Co-authored-by: Eric Kerstens <ekerstens@expediagroup.com>
1 parent a79563a commit ef3c6e1

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

faust/transport/consumer.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ class Consumer(Service, ConsumerT):
429429
_commit_every: Optional[int]
430430
_n_acked: int = 0
431431

432-
_active_partitions: Optional[Set[TP]]
432+
_active_partitions: Set[TP]
433433
_paused_partitions: Set[TP]
434434
_buffered_partitions: Set[TP]
435435

@@ -495,7 +495,7 @@ def on_init_dependencies(self) -> Iterable[ServiceT]:
495495
return []
496496

497497
def _reset_state(self) -> None:
498-
self._active_partitions = None
498+
self._active_partitions = set()
499499
self._paused_partitions = set()
500500
self._buffered_partitions = set()
501501
self.can_resume_flow.clear()
@@ -516,9 +516,12 @@ def _get_active_partitions(self) -> Set[TP]:
516516
return tps
517517

518518
def _set_active_tps(self, tps: Set[TP]) -> Set[TP]:
519-
xtps = self._active_partitions = ensure_TPset(tps) # copy
520-
xtps.difference_update(self._paused_partitions)
521-
return xtps
519+
if self._active_partitions is None:
520+
self._active_partitions = set()
521+
self._active_partitions.clear()
522+
self._active_partitions.update(ensure_TPset(tps))
523+
self._active_partitions.difference_update(self._paused_partitions)
524+
return self._active_partitions
522525

523526
def on_buffer_full(self, tp: TP) -> None:
524527
# do not remove the partition when in recovery
@@ -730,6 +733,13 @@ async def getmany(self, timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
730733
# convert timestamp to seconds from int milliseconds.
731734
yield tp, to_message(tp, record)
732735

736+
async def _wait_suspend(self):
737+
"""Wrapper around self.suspend_flow.wait() with no return value.
738+
739+
This allows for easily
740+
"""
741+
await self.suspend_flow.wait()
742+
733743
async def _wait_next_records(
734744
self, timeout: float
735745
) -> Tuple[Optional[RecordMap], Optional[Set[TP]]]:
@@ -750,10 +760,18 @@ async def _wait_next_records(
750760
# Fetch records only if active partitions to avoid the risk of
751761
# fetching all partitions in the beginning when none of the
752762
# partitions is paused/resumed.
753-
records = await self._getmany(
763+
_getmany = self._getmany(
754764
active_partitions=active_partitions,
755765
timeout=timeout,
756766
)
767+
wait_results = await self.wait_first(
768+
_getmany,
769+
self.suspend_flow.wait(),
770+
)
771+
for coro, result in zip(wait_results.done, wait_results.results):
772+
if coro is _getmany:
773+
records = result
774+
break
757775
else:
758776
# We should still release to the event loop
759777
await self.sleep(1)

0 commit comments

Comments
 (0)