
Commit 20222ae

more helpful warning for leaky backpressure; removed isinstance
1 parent 554a66e commit 20222ae

5 files changed: +17 -17 lines changed


examples/ezmsg_leaky_subscriber.py

Lines changed: 6 additions & 8 deletions
@@ -13,7 +13,8 @@
 # rather than processing a growing backlog of stale messages.
 
 import asyncio
-from collections.abc import AsyncGenerator
+import typing
+
 from dataclasses import dataclass, field
 
 import ezmsg.core as ez
@@ -40,9 +41,7 @@ class FastPublisher(ez.Unit):
     OUTPUT = ez.OutputStream(TimestampedMessage, num_buffers=32)
 
     @ez.publisher(OUTPUT)
-    async def publish(self) -> AsyncGenerator:
-        # Small delay to ensure subscriber is ready
-        await asyncio.sleep(0.5)
+    async def publish(self) -> typing.AsyncGenerator:
 
         for seq in range(self.SETTINGS.num_messages):
             msg = TimestampedMessage(seq=seq)
@@ -60,8 +59,8 @@ class SlowSubscriberSettings(ez.Settings):
 
 
 class SlowSubscriberState(ez.State):
+    received_seqs: list
     received_count: int = 0
-    received_seqs: list = None
     total_latency: float = 0.0
 
 
@@ -76,9 +75,8 @@ class SlowSubscriber(ez.Unit):
     SETTINGS = SlowSubscriberSettings
     STATE = SlowSubscriberState
 
-    # Leaky input stream with max queue of 3 messages
-    # When the queue fills up, oldest messages are dropped
-    INPUT = ez.InputStream(TimestampedMessage, leaky=True, max_queue=3)
+    # Leaky input stream; oldest messages are dropped
+    INPUT = ez.InputStream(TimestampedMessage, leaky=True)
 
     async def initialize(self) -> None:
         self.STATE.received_seqs = []
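
This diff drops the example's startup sleep and switches to a leaky INPUT stream with the default queue size. The handler that actually consumes the stream is not part of this diff; for context, here is a minimal sketch of what it might look like, assuming the usual ezmsg @ez.subscriber pattern. The timestamp field, the handler name on_message, and the 0.1 s simulated processing delay are illustrative assumptions, not taken from the commit.

import asyncio
import time
from dataclasses import dataclass, field

import ezmsg.core as ez


@dataclass
class TimestampedMessage:
    seq: int
    timestamp: float = field(default_factory=time.time)  # field name assumed


class SlowSubscriberState(ez.State):
    received_seqs: list  # no default; set in initialize(), as in the diff above
    received_count: int = 0
    total_latency: float = 0.0


class SlowSubscriber(ez.Unit):
    STATE = SlowSubscriberState

    # Leaky input: when the bounded queue fills, the oldest message is dropped
    INPUT = ez.InputStream(TimestampedMessage, leaky=True)

    async def initialize(self) -> None:
        self.STATE.received_seqs = []

    @ez.subscriber(INPUT)
    async def on_message(self, msg: TimestampedMessage) -> None:
        await asyncio.sleep(0.1)  # simulated slow processing
        self.STATE.received_seqs.append(msg.seq)
        self.STATE.received_count += 1
        self.STATE.total_latency += time.time() - msg.timestamp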

src/ezmsg/core/backendprocess.py

Lines changed: 1 addition & 2 deletions
@@ -416,15 +416,14 @@ async def handle_subscriber(
     # allowing publishers to continue without blocking during slow processing.
     # Non-leaky subscribers use recv_zero_copy() to hold backpressure during
     # processing, which provides zero-copy performance but applies backpressure.
-    is_leaky = isinstance(sub._incoming, LeakyQueue)
 
     while True:
         if not callables:
             sub.close()
             await sub.wait_closed()
             break
 
-        if is_leaky:
+        if sub.leaky:
             msg = await sub.recv()
             try:
                 for callable in list(callables):
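
The isinstance probe against the private sub._incoming queue is replaced by the new public Subscriber.leaky flag (added in subclient.py below). As background for the comments in this hunk, the following toy class illustrates the "leaky" semantics they describe: a bounded queue that evicts its oldest entry instead of blocking the producer. It is a stand-in for illustration only, not ezmsg's LeakyQueue.

import asyncio


class TinyLeakyQueue(asyncio.Queue):
    """Bounded queue that drops the oldest item rather than blocking when full."""

    def __init__(self, maxsize: int, on_drop=None):
        super().__init__(maxsize)
        self._on_drop = on_drop  # optional callback, mirrors a dropped-message notification

    def put_nowait(self, item) -> None:
        if self.full():
            dropped = self.get_nowait()  # evict the oldest queued message
            if self._on_drop is not None:
                self._on_drop(dropped)
        super().put_nowait(item)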

src/ezmsg/core/messagechannel.py

Lines changed: 3 additions & 0 deletions
@@ -164,6 +164,8 @@ async def create(
         writer.write(Command.CHANNEL.value)
         writer.write(encode_str(id_str))
 
+        topic = await read_str(reader)
+
         shm = None
         shm_name = await read_str(reader)
         try:
@@ -184,6 +186,7 @@ async def create(
         assert num_buffers > 0, "publisher reports invalid num_buffers"
 
         chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address, _guard=cls._SENTINEL)
+        chan.topic = topic
 
         chan._graph_task = asyncio.create_task(
             chan._graph_connection(graph_reader, graph_writer),

src/ezmsg/core/pubclient.py

Lines changed: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ async def _channel_connect(
         if cmd == Command.CHANNEL.value:
             channel_id_str = await read_str(reader)
             channel_id = UUID(channel_id_str)
+            writer.write(encode_str(self.topic))
             writer.write(encode_str(self._shm.name))
             shm_ok = await reader.read(1) == Command.SHM_OK.value
             pid = await read_int(reader)
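
messagechannel.py and pubclient.py are two halves of one handshake change: the publisher now sends its topic immediately after the CHANNEL exchange, and the channel reads it before the shared-memory name, so that subscriber-side warnings can name the offending publisher via channel.topic. The toy below illustrates why the write order and read order must match on both ends of the stream; encode_str and read_str here are stand-ins, and the 4-byte length prefix is an assumption, not ezmsg's actual wire format.

import asyncio


def encode_str(s: str) -> bytes:
    data = s.encode()
    return len(data).to_bytes(4, "big") + data


async def read_str(reader: asyncio.StreamReader) -> str:
    length = int.from_bytes(await reader.readexactly(4), "big")
    return (await reader.readexactly(length)).decode()


async def publisher_side(writer: asyncio.StreamWriter) -> None:
    writer.write(encode_str("EXAMPLE/TOPIC"))  # topic first (new in this commit)
    writer.write(encode_str("psm_demo_shm"))   # then the shared-memory name
    await writer.drain()


async def channel_side(reader: asyncio.StreamReader) -> None:
    topic = await read_str(reader)    # read in the same order the publisher wrote
    shm_name = await read_str(reader)
    print(f"publisher topic={topic!r}, shm={shm_name!r}")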

src/ezmsg/core/subclient.py

Lines changed: 6 additions & 7 deletions
@@ -35,6 +35,7 @@ class Subscriber:
 
     id: UUID
     topic: str
+    leaky: bool
 
     _graph_address: AddressType | None
     _graph_task: asyncio.Task[None]
@@ -124,11 +125,12 @@ def __init__(
         )
         self.id = id
         self.topic = topic
+        self.leaky = leaky
         self._graph_address = graph_address
 
         self._cur_pubs = set()
         self._channels = dict()
-        if leaky:
+        if self.leaky:
             self._incoming = LeakyQueue(
                 1 if max_queue is None else max_queue, self._handle_dropped_notification
             )
@@ -227,15 +229,12 @@ async def _graph_connection(
                pub_id, self.id, self._incoming, self._graph_address
            )
 
-           if (
-               isinstance(self._incoming, LeakyQueue)
-               and self._incoming.maxsize >= channel.num_buffers
-           ):
+           if self.leaky and self._incoming.maxsize >= channel.num_buffers:
               logger.warning(
                   f"Leaky Subscriber {self.topic} may cause "
-                  f"backpressure in Publisher {channel.topic}."
+                  f"backpressure in Publisher {channel.topic}. "
                  f"Subscriber's max queue size ({self._incoming.maxsize}) >= "
-                 f"Publisher's num_buffers ({channel.num_buffers})"
+                 f"Publisher's num_buffers ({channel.num_buffers})."
              )
 
           self._channels[pub_id] = channel
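
The rewritten warning fires when a leaky subscriber's queue can hold at least as many messages as the publisher has shared-memory buffers: in that case every buffer can be pinned by a queued-but-unprocessed message, the publisher has nowhere left to write, and backpressure returns despite the leaky queue. A back-of-the-envelope check of the condition, using numbers that appear in this commit (num_buffers=32 from the example's OutputStream, the example's old max_queue=3, and the default queue size of 1); the helper name is illustrative.

def may_cause_backpressure(max_queue: int, num_buffers: int) -> bool:
    # Same comparison as the new warning condition:
    # self._incoming.maxsize >= channel.num_buffers
    return max_queue >= num_buffers


assert may_cause_backpressure(max_queue=32, num_buffers=32)      # warning would fire
assert not may_cause_backpressure(max_queue=3, num_buffers=32)   # the example's old setting: fine
assert not may_cause_backpressure(max_queue=1, num_buffers=32)   # the default queue size: fine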
