
Commit e524eaa

Merge pull request #208 from ezmsg-org/feat/leaky-sub
feat: implement leaky subscribers.
2 parents eaf8068 + 20222ae · commit e524eaa

9 files changed: +790 −25 lines changed


examples/ezmsg_leaky_subscriber.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# Leaky Subscriber Example
#
# This example demonstrates the "leaky subscriber" feature, which allows
# slow consumers to drop old messages rather than blocking fast producers.
#
# Scenario:
# - A fast publisher produces messages at ~10 Hz (every 100ms)
# - A slow subscriber processes messages at ~1 Hz (1000ms per message)
# - Without leaky mode: the publisher would be blocked by backpressure
# - With leaky mode: old messages are dropped, subscriber always gets recent data
#
# This is useful for real-time applications where you want the latest data
# rather than processing a growing backlog of stale messages.

import asyncio
import typing

from dataclasses import dataclass, field

import ezmsg.core as ez


@dataclass
class TimestampedMessage:
    """A message with sequence number and timestamp for tracking latency."""

    seq: int
    created_at: float = field(default_factory=lambda: asyncio.get_event_loop().time())


class FastPublisherSettings(ez.Settings):
    num_messages: int = 20
    publish_interval_sec: float = 0.1  # 10 Hz


class FastPublisher(ez.Unit):
    """Publishes messages at ~10 Hz."""

    SETTINGS = FastPublisherSettings

    OUTPUT = ez.OutputStream(TimestampedMessage, num_buffers=32)

    @ez.publisher(OUTPUT)
    async def publish(self) -> typing.AsyncGenerator:

        for seq in range(self.SETTINGS.num_messages):
            msg = TimestampedMessage(seq=seq)
            print(f"[Publisher] Sending seq={seq}", flush=True)
            yield (self.OUTPUT, msg)
            await asyncio.sleep(self.SETTINGS.publish_interval_sec)

        print("[Publisher] Done sending all messages", flush=True)
        raise ez.Complete


class SlowSubscriberSettings(ez.Settings):
    process_time_sec: float = 1.0  # Simulates slow processing at ~1 Hz
    expected_messages: int = 20


class SlowSubscriberState(ez.State):
    received_seqs: list
    received_count: int = 0
    total_latency: float = 0.0


class SlowSubscriber(ez.Unit):
    """
    A slow subscriber that takes 1 second to process each message.

    Uses a leaky InputStream to drop old messages when it can't keep up,
    ensuring it always processes relatively recent data.
    """

    SETTINGS = SlowSubscriberSettings
    STATE = SlowSubscriberState

    # Leaky input stream; oldest messages are dropped
    INPUT = ez.InputStream(TimestampedMessage, leaky=True)

    async def initialize(self) -> None:
        self.STATE.received_seqs = []

    @ez.subscriber(INPUT)
    async def on_message(self, msg: TimestampedMessage) -> None:
        now = asyncio.get_event_loop().time()
        latency_ms = (now - msg.created_at) * 1000

        self.STATE.received_count += 1
        self.STATE.total_latency += latency_ms
        self.STATE.received_seqs.append(msg.seq)

        print(
            f"[Subscriber] Processing seq={msg.seq:3d}, latency={latency_ms:6.0f}ms",
            flush=True,
        )

        # Simulate slow processing
        await asyncio.sleep(self.SETTINGS.process_time_sec)

        # Terminate after receiving the last message
        if msg.seq == self.SETTINGS.expected_messages - 1:
            raise ez.NormalTermination

    async def shutdown(self) -> None:
        dropped = self.SETTINGS.expected_messages - self.STATE.received_count
        avg_latency = (
            self.STATE.total_latency / self.STATE.received_count
            if self.STATE.received_count > 0
            else 0
        )

        print("\n" + "=" * 60, flush=True)
        print("LEAKY SUBSCRIBER SUMMARY", flush=True)
        print("=" * 60, flush=True)
        print(f"  Messages published: {self.SETTINGS.expected_messages}", flush=True)
        print(f"  Messages received:  {self.STATE.received_count}", flush=True)
        print(f"  Messages dropped:   {dropped}", flush=True)
        print(f"  Sequences received: {self.STATE.received_seqs}", flush=True)
        print(f"  Average latency:    {avg_latency:.0f}ms", flush=True)
        print("=" * 60, flush=True)
        print(
            "\nNote: With leaky=True, the subscriber drops old messages to stay\n"
            "      current. Without it, backpressure would slow the publisher.",
            flush=True,
        )


class LeakyDemo(ez.Collection):
    """Demo system with a fast publisher and slow leaky subscriber."""

    SETTINGS = FastPublisherSettings

    PUB = FastPublisher()
    SUB = SlowSubscriber()

    def configure(self) -> None:
        num_msgs = self.SETTINGS.num_messages
        self.PUB.apply_settings(
            FastPublisherSettings(
                num_messages=num_msgs,
                publish_interval_sec=self.SETTINGS.publish_interval_sec,
            )
        )
        self.SUB.apply_settings(
            SlowSubscriberSettings(process_time_sec=1.0, expected_messages=num_msgs)
        )

    def network(self) -> ez.NetworkDefinition:
        return ((self.PUB.OUTPUT, self.SUB.INPUT),)


if __name__ == "__main__":
    print("Leaky Subscriber Demo", flush=True)
    print("=" * 60, flush=True)
    print("Publisher:  20 messages at 10 Hz (100ms intervals)", flush=True)
    print("Subscriber: Processes at 1 Hz (1000ms per message)", flush=True)
    print("Queue:      max_queue=3, leaky=True", flush=True)
    print("=" * 60, flush=True)
    print("\nExpected behavior:", flush=True)
    print("- Publisher sends 20 messages over ~2 seconds", flush=True)
    print("- Subscriber can only process ~1 message per second", flush=True)
    print("- Many messages will be dropped to keep subscriber current", flush=True)
    print("=" * 60 + "\n", flush=True)

    settings = FastPublisherSettings(num_messages=20, publish_interval_sec=0.1)
    system = LeakyDemo(settings)
    ez.run(DEMO=system)
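
The essential opt-in above is the leaky=True argument to ez.InputStream. Below is a minimal sketch of the same pattern; note that passing max_queue directly to ez.InputStream is an assumption here, suggested by the demo banner and by the stream.max_queue attribute used in backendprocess.py below.

import ezmsg.core as ez


class LatestOnly(ez.Unit):
    # Hypothetical unit that only cares about recent data.
    # leaky=True evicts the oldest queued message instead of blocking the
    # publisher; max_queue (assumed keyword) bounds the local queue depth.
    INPUT = ez.InputStream(float, leaky=True, max_queue=3)

    @ez.subscriber(INPUT)
    async def on_message(self, msg: float) -> None:
        ...  # slow work here; older messages may be dropped in the meantime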

src/ezmsg/core/backendprocess.py

Lines changed: 25 additions & 2 deletions
@@ -20,6 +20,7 @@

 from .stream import Stream, InputStream, OutputStream
 from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR
+from .messagechannel import LeakyQueue

 from .graphcontext import GraphContext
 from .pubclient import Publisher
@@ -201,7 +202,12 @@ async def setup_state():
         if isinstance(stream, InputStream):
             logger.debug(f"Creating Subscriber from {stream}")
             sub = asyncio.run_coroutine_threadsafe(
-                context.subscriber(stream.address), loop
+                context.subscriber(
+                    stream.address,
+                    leaky=stream.leaky,
+                    max_queue=stream.max_queue,
+                ),
+                loop,
             ).result()
             task_name = f"SUBSCRIBER|{stream.address}"
             coro_callables[task_name] = partial(
@@ -406,12 +412,19 @@ async def handle_subscriber(
     :param callables: Set of async callables to invoke with messages.
     :type callables: set[Callable[..., Coroutine[Any, Any, None]]]
     """
+    # Leaky subscribers use recv() to copy and release backpressure immediately,
+    # allowing publishers to continue without blocking during slow processing.
+    # Non-leaky subscribers use recv_zero_copy() to hold backpressure during
+    # processing, which provides zero-copy performance but applies backpressure.
+
     while True:
         if not callables:
             sub.close()
             await sub.wait_closed()
             break
-        async with sub.recv_zero_copy() as msg:
+
+        if sub.leaky:
+            msg = await sub.recv()
             try:
                 for callable in list(callables):
                     try:
@@ -420,6 +433,16 @@ async def handle_subscriber(
                         callables.remove(callable)
             finally:
                 del msg
+        else:
+            async with sub.recv_zero_copy() as msg:
+                try:
+                    for callable in list(callables):
+                        try:
+                            await callable(msg)
+                        except (Complete, NormalTermination):
+                            callables.remove(callable)
+                        finally:
+                            del msg

         if len(callables) > 1:
             await asyncio.sleep(0)
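
The comment added to handle_subscriber above contrasts two receive disciplines. Below is a minimal sketch of that contrast, assuming only the Subscriber surface visible in this diff (recv(), recv_zero_copy(), and the leaky flag); it restates the shape of the branch and is not additional ezmsg API.

# Non-leaky path: the shared-memory buffer stays held for the whole block,
# so a slow handler applies backpressure to the publisher.
async def consume_non_leaky(sub, handle) -> None:
    async with sub.recv_zero_copy() as msg:
        await handle(msg)


# Leaky path: recv() hands back a copy and releases the buffer immediately,
# so the publisher keeps producing while the handler works.
async def consume_leaky(sub, handle) -> None:
    msg = await sub.recv()
    await handle(msg)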

src/ezmsg/core/messagechannel.py

Lines changed: 80 additions & 11 deletions
@@ -25,7 +25,48 @@
 logger = logging.getLogger("ezmsg")


-NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]]
+class LeakyQueue(asyncio.Queue[typing.Tuple[UUID, int]]):
+    """
+    An asyncio.Queue that drops oldest items when full.
+
+    When putting a new item into a full queue, the oldest item is
+    dropped to make room.
+
+    :param maxsize: Maximum queue size (must be positive)
+    :param on_drop: Optional callback called with dropped item when dropping
+    """
+
+    def __init__(
+        self,
+        maxsize: int,
+        on_drop: typing.Callable[[typing.Any], None] | None = None,
+    ):
+        super().__init__(maxsize=maxsize)
+        self._on_drop = on_drop
+
+    def _drop_oldest(self) -> None:
+        """Drop the oldest item from the queue, calling on_drop if set."""
+        try:
+            dropped = self.get_nowait()
+            if self._on_drop is not None:
+                self._on_drop(dropped)
+        except asyncio.QueueEmpty:
+            pass
+
+    async def put(self, item: typing.Tuple[UUID, int]) -> None:
+        """Put an item into the queue, dropping oldest if full."""
+        if self.full():
+            self._drop_oldest()
+        await super().put(item)
+
+    def put_nowait(self, item: typing.Tuple[UUID, int]) -> None:
+        """Put an item without blocking, dropping oldest if full."""
+        if self.full():
+            self._drop_oldest()
+        super().put_nowait(item)
+
+
+NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]] | LeakyQueue


 class Channel:
@@ -123,6 +164,8 @@ async def create(
         writer.write(Command.CHANNEL.value)
         writer.write(encode_str(id_str))

+        topic = await read_str(reader)
+
         shm = None
         shm_name = await read_str(reader)
         try:
@@ -143,6 +186,7 @@ async def create(
         assert num_buffers > 0, "publisher reports invalid num_buffers"

         chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address, _guard=cls._SENTINEL)
+        chan.topic = topic

         chan._graph_task = asyncio.create_task(
             chan._graph_connection(graph_reader, graph_writer),
@@ -310,16 +354,41 @@ def get(
         try:
             yield self.cache[msg_id]
         finally:
-            buf_idx = msg_id % self.num_buffers
-            self.backpressure.free(client_id, buf_idx)
-            if self.backpressure.buffers[buf_idx].is_empty:
-                self.cache.release(msg_id)
-
-            # If pub is in same process as this channel, avoid TCP
-            if self._local_backpressure is not None:
-                self._local_backpressure.free(self.id, buf_idx)
-            else:
-                self._acknowledge(msg_id)
+            self._release_backpressure(msg_id, client_id)
+
+    def release_without_get(self, msg_id: int, client_id: UUID) -> None:
+        """
+        Release backpressure for a message without retrieving it.
+
+        Used by leaky subscribers when dropping notifications to ensure
+        backpressure is properly released for messages that will never be read.
+
+        :param msg_id: Message ID to release
+        :type msg_id: int
+        :param client_id: UUID of client releasing this message
+        :type client_id: UUID
+        """
+        self._release_backpressure(msg_id, client_id)
+
+    def _release_backpressure(self, msg_id: int, client_id: UUID) -> None:
+        """
+        Internal method to release backpressure for a message.
+
+        :param msg_id: Message ID to release
+        :type msg_id: int
+        :param client_id: UUID of client releasing this message
+        :type client_id: UUID
+        """
+        buf_idx = msg_id % self.num_buffers
+        self.backpressure.free(client_id, buf_idx)
+        if self.backpressure.buffers[buf_idx].is_empty:
+            self.cache.release(msg_id)
+
+        # If pub is in same process as this channel, avoid TCP
+        if self._local_backpressure is not None:
+            self._local_backpressure.free(self.id, buf_idx)
+        else:
+            self._acknowledge(msg_id)

     def _acknowledge(self, msg_id: int) -> None:
         try:
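
The drop-oldest behavior of LeakyQueue can be exercised on its own. Below is a small sketch, assuming the module path ezmsg.core.messagechannel implied by the backendprocess.py import above; the (UUID, int) items mimic the (publisher id, message id) notifications the queue actually carries.

import asyncio
from uuid import UUID, uuid4

from ezmsg.core.messagechannel import LeakyQueue


async def main() -> None:
    dropped: list[tuple[UUID, int]] = []

    # Hold at most 3 notifications; record anything evicted.
    q = LeakyQueue(maxsize=3, on_drop=dropped.append)

    pub_id = uuid4()
    for msg_id in range(6):
        # Never raises QueueFull: the oldest item is evicted instead.
        q.put_nowait((pub_id, msg_id))

    kept = [q.get_nowait()[1] for _ in range(q.qsize())]
    print("kept:   ", kept)                      # expected: [3, 4, 5]
    print("dropped:", [m for _, m in dropped])   # expected: [0, 1, 2]


if __name__ == "__main__":
    asyncio.run(main())

In the subscriber client, the on_drop hook is presumably where release_without_get (added above in this diff) gets called, so that backpressure is freed for notifications that are dropped before they are ever read.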

src/ezmsg/core/pubclient.py

Lines changed: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ async def _channel_connect(
         if cmd == Command.CHANNEL.value:
             channel_id_str = await read_str(reader)
             channel_id = UUID(channel_id_str)
+            writer.write(encode_str(self.topic))
             writer.write(encode_str(self._shm.name))
             shm_ok = await reader.read(1) == Command.SHM_OK.value
             pid = await read_int(reader)

0 commit comments
