88from asyncio .exceptions import CancelledError
99from collections import deque
1010from collections .abc import AsyncIterable
11- from contextlib import AbstractAsyncContextManager
1211from typing import TYPE_CHECKING , Concatenate , Generic , cast
13- from typing_extensions import Any , ParamSpec , Protocol , TypeVar , final , override
12+ from typing_extensions import Any , ParamSpec , TypeVar , final , override
1413
1514from duron ._core .ops import (
1615 Barrier ,
2120 StreamEmit ,
2221 create_op ,
2322)
24- from duron .loop import EventLoop , wrap_future
23+ from duron .loop import EventLoop , LoopClosedError , wrap_future
2524
2625if TYPE_CHECKING :
2726 from collections .abc import AsyncGenerator , Awaitable , Callable , Sequence
@@ -63,32 +62,10 @@ def reason(self) -> Exception | None:
6362 return cast ("Exception | None" , self .__cause__ )
6463
6564
66- class StreamWriter (
67- AbstractAsyncContextManager ["StreamWriter[_T_contra]" ], Protocol , Generic [_T_contra ]
68- ):
65+ @final
66+ class StreamWriter (Generic [_T_contra ]):
6967 """Protocol for writing values to a stream."""
7068
71- async def send (self , value : _T_contra , / ) -> None :
72- """Send a value to the stream.
73-
74- Args:
75- value: The value to send to stream consumers.
76-
77- """
78- ...
79-
80- async def close (self , error : Exception | None = None , / ) -> None :
81- """Close the stream, optionally with an error.
82-
83- Args:
84- error: Optional exception to signal an error condition to consumers.
85-
86- """
87- ...
88-
89-
90- @final
91- class OpWriter (Generic [_T_contra ]):
9269 __slots__ = ("_closed" , "_loop" , "_stream_id" )
9370
9471 def __init__ (self , stream_id : str , loop : EventLoop ) -> None :
@@ -97,11 +74,35 @@ def __init__(self, stream_id: str, loop: EventLoop) -> None:
9774 self ._closed = False
9875
9976 async def send (self , value : _T_contra , / ) -> None :
77+ """Send a value to the stream.
78+
79+ Raises:
80+ RuntimeError: If the stream is already closed.
81+
82+ Args:
83+ value: The value to send to stream consumers.
84+
85+ """
86+ if self ._closed :
87+ msg = "Cannot send to a closed stream"
88+ raise RuntimeError (msg )
10089 await wrap_future (
10190 create_op (self ._loop , StreamEmit (stream_id = self ._stream_id , value = value ))
10291 )
10392
10493 async def close (self , exception : Exception | None = None , / ) -> None :
94+ """Close the stream, optionally with an error.
95+
96+ Raises:
97+ RuntimeError: If the stream is already closed.
98+
99+ Args:
100+ exception: Optional exception to signal an error condition to consumers.
101+
102+ """
103+ if self ._closed :
104+ msg = "Cannot send to a closed stream"
105+ raise RuntimeError (msg )
105106 await wrap_future (
106107 create_op (
107108 self ._loop , StreamClose (stream_id = self ._stream_id , exception = exception )
@@ -120,14 +121,15 @@ async def __aexit__(
120121 ) -> None :
121122 if self ._closed :
122123 return
123- if not exc_value :
124- await self .close ()
125- elif isinstance (exc_value , Exception ):
126- await self .close (exc_value )
127- else :
128- await self .close (
129- Exception (f"StreamWriter exited with exception: { exc_value } " )
130- )
124+ with contextlib .suppress (LoopClosedError ):
125+ if not exc_value :
126+ await self .close ()
127+ elif isinstance (exc_value , Exception ):
128+ await self .close (exc_value )
129+ else :
130+ await self .close (
131+ Exception (f"StreamWriter exited with exception: { exc_value } " )
132+ )
131133
132134
133135class Stream (ABC , AsyncIterable [_T ], Generic [_T ]):
@@ -189,7 +191,7 @@ async def create_stream(
189191 sid = await create_op (
190192 loop , StreamCreate (dtype = dtype , observer = w , name = name , metadata = metadata )
191193 )
192- writer : OpWriter [_T ] = OpWriter (sid , loop )
194+ writer : StreamWriter [_T ] = StreamWriter (sid , loop )
193195 return (s , writer )
194196
195197
@@ -206,7 +208,7 @@ def __init__(self) -> None:
206208 self ._loop : asyncio .AbstractEventLoop | None = None
207209 self ._event : asyncio .Event | None = None
208210 self ._buffer : deque [tuple [int , _T | StreamClosed ]] = deque ()
209- self ._cursor : int = - 1
211+ self ._cursor : int = 0
210212
211213 @final
212214 @override
@@ -216,34 +218,41 @@ async def next(self, *, block: bool) -> Sequence[_T]:
216218 self ._event = asyncio .Event ()
217219
218220 if not block :
219- return await self ._next_nowait ()
220-
221- while True :
222- _ = await self ._event .wait ()
223221 self ._event .clear ()
224- if it := await self ._next_nowait ():
225- return it
226-
227- async def _next_nowait (self ) -> Sequence [_T ]:
228- if not self ._loop :
229- self ._loop = asyncio .get_running_loop ()
222+ begin , end = await self ._next_cursor ()
223+ return self ._pop (begin , end )
230224
231- if isinstance (self ._loop , EventLoop ):
232-
233- def cb (f : asyncio .Future [tuple [int , int ]]) -> None :
234- if not f .cancelled ():
235- offset , _ = f .result ()
236- self ._cursor = max (self ._cursor , offset )
237-
238- begin = self ._cursor
239- op = create_op (self ._loop , Barrier ())
240- op .add_done_callback (cb )
241- end , _ = await asyncio .shield (op )
242- self ._cursor = max (self ._cursor , end )
243- else :
225+ while True :
226+ try :
227+ _ = await self ._event .wait ()
228+ finally :
229+ self ._event .clear ()
230+
231+ begin , end = await self ._next_cursor ()
232+ items = self ._pop (begin , end )
233+ if items :
234+ return items
235+
236+ async def _next_cursor (self ) -> tuple [int , int | None ]:
237+ if not isinstance (self ._loop , EventLoop ):
238+ return (0 , None )
239+
240+ def cb (f : asyncio .Future [tuple [int , int ]]) -> None :
241+ if not f .cancelled ():
242+ offset , _ = f .result ()
243+ self ._cursor = max (self ._cursor , offset )
244+
245+ begin = self ._cursor
246+ op = create_op (self ._loop , Barrier ())
247+ op .add_done_callback (cb )
248+ end , _ = await asyncio .shield (op )
249+ self ._cursor = max (self ._cursor , end )
250+ return (begin , end )
251+
252+ def _pop (self , begin : int , end : int | None ) -> Sequence [_T ]:
253+ if end is None :
244254 if not self ._buffer :
245255 return ()
246- begin = 0
247256 end = self ._buffer [- 1 ][0 ] + 1
248257
249258 result : list [_T ] = []
@@ -328,7 +337,7 @@ async def run_stateful(
328337 stream : _StatefulStream [_U , _T ] = _StatefulStream (
329338 reducer , fn , initial , * args , ** kwargs
330339 )
331- sink : StreamWriter [_U ] = OpWriter (
340+ sink : StreamWriter [_U ] = StreamWriter (
332341 await create_op (
333342 loop ,
334343 StreamCreate (
0 commit comments