Skip to content

Commit bfccbc1

Browse files
committed
feat: handle more stream edge-cases
1 parent d7f5728 commit bfccbc1

File tree

6 files changed

+157
-73
lines changed

6 files changed

+157
-73
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ preview = true
7979
future-annotations = true
8080

8181
[tool.ruff.lint.per-file-ignores]
82-
"tests/**/*.py" = ["S101", "S311", "PLC2701", "PLR2004", "RUF029"]
82+
"tests/**/*.py" = ["S311", "PLC2701", "PLR2004", "RUF029"]
8383
"examples/**/*.py" = ["T201", "S311"]
8484

8585
[tool.ruff.lint.isort]
8686
extra-standard-library = ["typing_extensions"]
87-
split-on-trailing-comma=false
87+
split-on-trailing-comma = false
8888

8989
[tool.ruff.lint.flake8-type-checking]
9090
runtime-evaluated-base-classes = ["typing_extensions.TypedDict"]

src/duron/_core/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
create_op,
3232
)
3333
from duron._core.signal import Signal
34-
from duron._core.stream import OpWriter, Stream, StreamWriter, create_buffer_stream
34+
from duron._core.stream import Stream, StreamWriter, create_buffer_stream
3535
from duron._core.stream_manager import StreamManager
3636
from duron._core.task_manager import TaskError, TaskManager
3737
from duron._core.utils import decode_error, encode_error
@@ -754,7 +754,7 @@ async def open_stream(
754754
return self._streams.pop(name)
755755

756756
sid = await self._stream_manager.wait_stream(name)
757-
w: OpWriter[Any] = OpWriter(sid, self._loop)
757+
w: StreamWriter[Any] = StreamWriter(sid, self._loop)
758758
return w
759759

760760
def is_future_pending(self, future_id: str) -> bool:

src/duron/_core/signal.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
from typing_extensions import Any, TypeVar, final, override
88

99
from duron._core.ops import StreamCreate, create_op
10-
from duron._core.stream import OpWriter
10+
from duron._core.stream import StreamWriter
1111

1212
if TYPE_CHECKING:
1313
from types import TracebackType
1414

1515
from duron._core.ops import OpMetadata
16-
from duron._core.stream import StreamWriter
1716
from duron.loop import EventLoop
1817
from duron.typing import TypeHint
1918

@@ -120,5 +119,5 @@ async def create_signal(
120119
sid = await create_op(
121120
loop, StreamCreate(dtype=dtype, name=name, observer=s, metadata=metadata)
122121
)
123-
w: OpWriter[_T] = OpWriter(sid, loop)
122+
w: StreamWriter[_T] = StreamWriter(sid, loop)
124123
return (s, w)

src/duron/_core/stream.py

Lines changed: 71 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
from asyncio.exceptions import CancelledError
99
from collections import deque
1010
from collections.abc import AsyncIterable
11-
from contextlib import AbstractAsyncContextManager
1211
from typing import TYPE_CHECKING, Concatenate, Generic, cast
13-
from typing_extensions import Any, ParamSpec, Protocol, TypeVar, final, override
12+
from typing_extensions import Any, ParamSpec, TypeVar, final, override
1413

1514
from duron._core.ops import (
1615
Barrier,
@@ -21,7 +20,7 @@
2120
StreamEmit,
2221
create_op,
2322
)
24-
from duron.loop import EventLoop, wrap_future
23+
from duron.loop import EventLoop, LoopClosedError, wrap_future
2524

2625
if TYPE_CHECKING:
2726
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
@@ -63,32 +62,10 @@ def reason(self) -> Exception | None:
6362
return cast("Exception | None", self.__cause__)
6463

6564

66-
class StreamWriter(
67-
AbstractAsyncContextManager["StreamWriter[_T_contra]"], Protocol, Generic[_T_contra]
68-
):
65+
@final
66+
class StreamWriter(Generic[_T_contra]):
6967
"""Protocol for writing values to a stream."""
7068

71-
async def send(self, value: _T_contra, /) -> None:
72-
"""Send a value to the stream.
73-
74-
Args:
75-
value: The value to send to stream consumers.
76-
77-
"""
78-
...
79-
80-
async def close(self, error: Exception | None = None, /) -> None:
81-
"""Close the stream, optionally with an error.
82-
83-
Args:
84-
error: Optional exception to signal an error condition to consumers.
85-
86-
"""
87-
...
88-
89-
90-
@final
91-
class OpWriter(Generic[_T_contra]):
9269
__slots__ = ("_closed", "_loop", "_stream_id")
9370

9471
def __init__(self, stream_id: str, loop: EventLoop) -> None:
@@ -97,11 +74,35 @@ def __init__(self, stream_id: str, loop: EventLoop) -> None:
9774
self._closed = False
9875

9976
async def send(self, value: _T_contra, /) -> None:
77+
"""Send a value to the stream.
78+
79+
Raises:
80+
RuntimeError: If the stream is already closed.
81+
82+
Args:
83+
value: The value to send to stream consumers.
84+
85+
"""
86+
if self._closed:
87+
msg = "Cannot send to a closed stream"
88+
raise RuntimeError(msg)
10089
await wrap_future(
10190
create_op(self._loop, StreamEmit(stream_id=self._stream_id, value=value))
10291
)
10392

10493
async def close(self, exception: Exception | None = None, /) -> None:
94+
"""Close the stream, optionally with an error.
95+
96+
Raises:
97+
RuntimeError: If the stream is already closed.
98+
99+
Args:
100+
exception: Optional exception to signal an error condition to consumers.
101+
102+
"""
103+
if self._closed:
104+
msg = "Cannot send to a closed stream"
105+
raise RuntimeError(msg)
105106
await wrap_future(
106107
create_op(
107108
self._loop, StreamClose(stream_id=self._stream_id, exception=exception)
@@ -120,14 +121,15 @@ async def __aexit__(
120121
) -> None:
121122
if self._closed:
122123
return
123-
if not exc_value:
124-
await self.close()
125-
elif isinstance(exc_value, Exception):
126-
await self.close(exc_value)
127-
else:
128-
await self.close(
129-
Exception(f"StreamWriter exited with exception: {exc_value}")
130-
)
124+
with contextlib.suppress(LoopClosedError):
125+
if not exc_value:
126+
await self.close()
127+
elif isinstance(exc_value, Exception):
128+
await self.close(exc_value)
129+
else:
130+
await self.close(
131+
Exception(f"StreamWriter exited with exception: {exc_value}")
132+
)
131133

132134

133135
class Stream(ABC, AsyncIterable[_T], Generic[_T]):
@@ -189,7 +191,7 @@ async def create_stream(
189191
sid = await create_op(
190192
loop, StreamCreate(dtype=dtype, observer=w, name=name, metadata=metadata)
191193
)
192-
writer: OpWriter[_T] = OpWriter(sid, loop)
194+
writer: StreamWriter[_T] = StreamWriter(sid, loop)
193195
return (s, writer)
194196

195197

@@ -206,7 +208,7 @@ def __init__(self) -> None:
206208
self._loop: asyncio.AbstractEventLoop | None = None
207209
self._event: asyncio.Event | None = None
208210
self._buffer: deque[tuple[int, _T | StreamClosed]] = deque()
209-
self._cursor: int = -1
211+
self._cursor: int = 0
210212

211213
@final
212214
@override
@@ -216,34 +218,41 @@ async def next(self, *, block: bool) -> Sequence[_T]:
216218
self._event = asyncio.Event()
217219

218220
if not block:
219-
return await self._next_nowait()
220-
221-
while True:
222-
_ = await self._event.wait()
223221
self._event.clear()
224-
if it := await self._next_nowait():
225-
return it
226-
227-
async def _next_nowait(self) -> Sequence[_T]:
228-
if not self._loop:
229-
self._loop = asyncio.get_running_loop()
222+
begin, end = await self._next_cursor()
223+
return self._pop(begin, end)
230224

231-
if isinstance(self._loop, EventLoop):
232-
233-
def cb(f: asyncio.Future[tuple[int, int]]) -> None:
234-
if not f.cancelled():
235-
offset, _ = f.result()
236-
self._cursor = max(self._cursor, offset)
237-
238-
begin = self._cursor
239-
op = create_op(self._loop, Barrier())
240-
op.add_done_callback(cb)
241-
end, _ = await asyncio.shield(op)
242-
self._cursor = max(self._cursor, end)
243-
else:
225+
while True:
226+
try:
227+
_ = await self._event.wait()
228+
finally:
229+
self._event.clear()
230+
231+
begin, end = await self._next_cursor()
232+
items = self._pop(begin, end)
233+
if items:
234+
return items
235+
236+
async def _next_cursor(self) -> tuple[int, int | None]:
237+
if not isinstance(self._loop, EventLoop):
238+
return (0, None)
239+
240+
def cb(f: asyncio.Future[tuple[int, int]]) -> None:
241+
if not f.cancelled():
242+
offset, _ = f.result()
243+
self._cursor = max(self._cursor, offset)
244+
245+
begin = self._cursor
246+
op = create_op(self._loop, Barrier())
247+
op.add_done_callback(cb)
248+
end, _ = await asyncio.shield(op)
249+
self._cursor = max(self._cursor, end)
250+
return (begin, end)
251+
252+
def _pop(self, begin: int, end: int | None) -> Sequence[_T]:
253+
if end is None:
244254
if not self._buffer:
245255
return ()
246-
begin = 0
247256
end = self._buffer[-1][0] + 1
248257

249258
result: list[_T] = []
@@ -328,7 +337,7 @@ async def run_stateful(
328337
stream: _StatefulStream[_U, _T] = _StatefulStream(
329338
reducer, fn, initial, *args, **kwargs
330339
)
331-
sink: StreamWriter[_U] = OpWriter(
340+
sink: StreamWriter[_U] = StreamWriter(
332341
await create_op(
333342
loop,
334343
StreamCreate(

src/duron/loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ async def shutdown_asyncgens(self) -> None:
354354
pass
355355

356356
@override
357-
async def shutdown_default_executor(self) -> None:
357+
async def shutdown_default_executor(self, timeout: float | None = None) -> None:
358358
pass
359359

360360
def _timer_handle_cancelled(self, _th: asyncio.TimerHandle) -> None:

tests/test_stream.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@
1010

1111
from duron import (
1212
Context,
13+
Provided,
1314
Reducer,
1415
Session,
16+
Signal,
1517
SignalInterrupt,
18+
Stream,
1619
StreamClosed,
1720
durable,
1821
effect,
1922
)
2023
from duron.contrib.storage import MemoryLogStorage
24+
from duron.loop import LoopClosedError
2125

2226
if TYPE_CHECKING:
2327
from duron import StreamWriter
@@ -254,7 +258,79 @@ async def trigger_signal() -> None:
254258
result2 = await (await t.start(activity)).result()
255259
b = len(await log.entries())
256260
for chunk1, chunk2 in zip(result1, result2, strict=True):
257-
length = min(len(chunk1), len(chunk2))
258-
assert len(chunk2) >= len(chunk1)
259-
assert chunk1[:length] == chunk2[:length]
261+
assert len(chunk2) == len(chunk1)
262+
assert chunk1 == chunk2
263+
assert a == b
264+
265+
266+
@pytest.mark.asyncio
267+
async def test_external_stream_signal_timing() -> None:
268+
@durable()
269+
async def activity(
270+
_ctx: Context,
271+
input_stream: Stream[int] = Provided,
272+
interrupt_signal: Signal[str] = Provided,
273+
) -> list[list[list[int]]]:
274+
results: list[list[list[int]]] = []
275+
276+
# Consume with signal context
277+
for i in range(1000):
278+
batch: list[list[int]] = []
279+
try:
280+
async with interrupt_signal:
281+
# Consume values until interrupted
282+
while True:
283+
if i % 3 == 0:
284+
batch.append(list(await input_stream.next(block=False)))
285+
await asyncio.sleep(0.01)
286+
else:
287+
batch.append(list(await input_stream.next(block=True)))
288+
except StreamClosed:
289+
results.append(batch)
290+
except SignalInterrupt as e:
291+
results.append(batch)
292+
if e.value is True:
293+
break
294+
return results
295+
296+
# First run
297+
log = MemoryLogStorage()
298+
async with Session(log) as sess:
299+
run = await sess.start(activity)
300+
input_stream = await run.open_stream("input_stream", "w")
301+
interrupt_signal = await run.open_stream("interrupt_signal", "w")
302+
303+
# Emit values continuously
304+
async def emitter() -> None:
305+
async with input_stream as w:
306+
for i in range(500):
307+
try:
308+
await w.send(i)
309+
except LoopClosedError:
310+
break
311+
await asyncio.sleep(random.random() * 0.001)
312+
313+
# Trigger signal at specific time
314+
async def trigger_signal() -> None:
315+
async with interrupt_signal as w:
316+
for i in range(5):
317+
await asyncio.sleep(0.015)
318+
await w.send(i == 4)
319+
320+
result1, _, _ = await asyncio.gather(
321+
run.result(),
322+
asyncio.create_task(emitter()),
323+
asyncio.create_task(trigger_signal()),
324+
)
325+
326+
# Replay
327+
a = len(await log.entries())
328+
async with Session(log) as sess:
329+
result2 = await (await sess.start(activity)).result()
330+
b = len(await log.entries())
331+
332+
# Verify results match
333+
for chunk1, chunk2 in zip(result1, result2, strict=True):
334+
assert len(chunk2) == len(chunk1)
335+
assert chunk1 == chunk2
260336
assert a == b

0 commit comments

Comments
 (0)