diff --git a/python-client/.gitignore b/python-client/.gitignore new file mode 100644 index 00000000..befa7815 --- /dev/null +++ b/python-client/.gitignore @@ -0,0 +1,9 @@ +.venv/ +__pycache__/ +*.pyc +*.egg-info/ +.pytest_cache/ +dist/ +build/ +# esbuild build artifact (built from test_server.ts at test time) +tests/test_server.mjs diff --git a/python-client/pyproject.toml b/python-client/pyproject.toml new file mode 100644 index 00000000..9d681b81 --- /dev/null +++ b/python-client/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "river-client" +version = "0.1.0" +description = "Python client for River protocol v2.0" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "MIT"} +dependencies = [ + "websockets>=12.0", + "msgpack>=1.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", +] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] + +[tool.setuptools.packages.find] +include = ["river*"] diff --git a/python-client/river/__init__.py b/python-client/river/__init__.py new file mode 100644 index 00000000..01b30d92 --- /dev/null +++ b/python-client/river/__init__.py @@ -0,0 +1,19 @@ +"""River protocol v2.0 Python client implementation.""" + +from river.types import TransportMessage, Ok, Err +from river.codec import NaiveJsonCodec, BinaryCodec +from river.transport import WebSocketClientTransport +from river.client import RiverClient +from river.streams import Readable, Writable + +__all__ = [ + "RiverClient", + "WebSocketClientTransport", + "NaiveJsonCodec", + "BinaryCodec", + "TransportMessage", + "Ok", + "Err", + "Readable", + "Writable", +] diff --git a/python-client/river/client.py b/python-client/river/client.py new file mode 100644 index 00000000..eb6210d3 --- /dev/null +++ b/python-client/river/client.py @@ -0,0 +1,468 @@ +"""River client for invoking remote procedures. 
@dataclass
class RpcResult:
    """Typed view of an RPC result.

    NOTE(review): RiverClient.rpc currently returns the raw result dict
    ({"ok": ..., "payload": ...}) rather than instances of this class;
    this container is kept for callers that want a typed wrapper.
    """

    ok: bool      # True on success, False on error
    payload: Any  # procedure output on success, error payload on failure


@dataclass
class StreamResult:
    """Handles returned when opening a stream procedure."""

    req_writable: Writable  # send request messages to the server
    res_readable: Readable  # iterate server responses


@dataclass
class UploadResult:
    """Handles returned when opening an upload procedure."""

    req_writable: Writable       # send upload chunks
    finalize: Callable[[], Any]  # async callable resolving to the result dict


@dataclass
class SubscriptionResult:
    """Handle returned when opening a subscription procedure."""

    res_readable: Readable  # iterate server-pushed messages
+ client = RiverClient(transport, server_id="my-server") + + # RPC + result = await client.rpc("service", "procedure", {"arg": 1}) + + # Stream + stream = client.stream("service", "procedure", {"arg": 1}) + stream.req_writable.write({"data": "hello"}) + async for msg in stream.res_readable: + print(msg) + + # Upload + upload = client.upload("service", "procedure", {"arg": 1}) + upload.req_writable.write({"data": "chunk1"}) + upload.req_writable.close() + result = await upload.finalize() + + # Subscription + sub = client.subscribe("service", "procedure", {"arg": 1}) + async for msg in sub.res_readable: + print(msg) + """ + + def __init__( + self, + transport: WebSocketClientTransport, + server_id: str | None = None, + connect_on_invoke: bool = True, + eagerly_connect: bool = False, + ) -> None: + self._transport = transport + self._server_id = server_id or transport.server_id + self._connect_on_invoke = connect_on_invoke + + if eagerly_connect: + transport.connect(self._server_id) + + @property + def transport(self) -> WebSocketClientTransport: + return self._transport + + async def rpc( + self, + service_name: str, + procedure_name: str, + init: Any, + abort_signal: asyncio.Event | None = None, + ) -> dict[str, Any]: + """Invoke an RPC procedure. + + Returns the result dict: {"ok": True/False, "payload": ...} + """ + result = self._handle_proc( + proc_type="rpc", + service_name=service_name, + procedure_name=procedure_name, + init=init, + abort_signal=abort_signal, + ) + # For RPC, we await the single response + readable = result["res_readable"] + done, value = await readable.next() + if done: + return err_result( + UNEXPECTED_DISCONNECT_CODE, "No response received" + ) + return value + + def stream( + self, + service_name: str, + procedure_name: str, + init: Any, + abort_signal: asyncio.Event | None = None, + ) -> StreamResult: + """Open a stream procedure. + + Returns StreamResult with req_writable and res_readable. 
+ """ + result = self._handle_proc( + proc_type="stream", + service_name=service_name, + procedure_name=procedure_name, + init=init, + abort_signal=abort_signal, + ) + return StreamResult( + req_writable=result["req_writable"], + res_readable=result["res_readable"], + ) + + def upload( + self, + service_name: str, + procedure_name: str, + init: Any, + abort_signal: asyncio.Event | None = None, + ) -> UploadResult: + """Open an upload procedure. + + Returns UploadResult with req_writable and finalize(). + """ + result = self._handle_proc( + proc_type="upload", + service_name=service_name, + procedure_name=procedure_name, + init=init, + abort_signal=abort_signal, + ) + + async def finalize() -> dict[str, Any]: + readable = result["res_readable"] + done, value = await readable.next() + if done: + return err_result( + UNEXPECTED_DISCONNECT_CODE, "No response received" + ) + return value + + return UploadResult( + req_writable=result["req_writable"], + finalize=finalize, + ) + + def subscribe( + self, + service_name: str, + procedure_name: str, + init: Any, + abort_signal: asyncio.Event | None = None, + ) -> SubscriptionResult: + """Open a subscription procedure. + + Returns SubscriptionResult with res_readable. + """ + result = self._handle_proc( + proc_type="subscription", + service_name=service_name, + procedure_name=procedure_name, + init=init, + abort_signal=abort_signal, + ) + return SubscriptionResult(res_readable=result["res_readable"]) + + def _handle_proc( + self, + proc_type: str, + service_name: str, + procedure_name: str, + init: Any, + abort_signal: asyncio.Event | None = None, + ) -> dict[str, Any]: + """Core procedure dispatch logic. + + Sets up the stream, registers message handlers, sends the init message. 
+ """ + to = self._server_id + transport = self._transport + + # If transport is closed, return immediate disconnect error + if transport.get_status() != "open": + res_readable = Readable() + res_readable._push_value( + err_result( + UNEXPECTED_DISCONNECT_CODE, "transport is closed" + ) + ) + res_readable._trigger_close() + req_writable = Writable(write_cb=lambda _: None, close_cb=None) + req_writable._closed = True + return { + "res_readable": res_readable, + "req_writable": req_writable, + } + + # Connect if needed + if self._connect_on_invoke: + transport.connect(to) + + # Get the session and a send function + session = transport._get_or_create_session(to) + session_id = session.id + try: + send_fn = transport.get_session_bound_send_fn(to, session_id) + except RuntimeError: + # Session already dead + res_readable = Readable() + res_readable._push_value( + err_result( + UNEXPECTED_DISCONNECT_CODE, + f"{to} unexpectedly disconnected", + ) + ) + res_readable._trigger_close() + req_writable = Writable(write_cb=lambda _: None, close_cb=None) + req_writable._closed = True + return { + "res_readable": res_readable, + "req_writable": req_writable, + } + + # Determine flags + proc_closes_with_init = proc_type in ("rpc", "subscription") + stream_id = generate_id() + + # Create readable for responses + res_readable: Readable = Readable() + + # Tracking state + clean_close = True + cleaned_up = False + + def cleanup(): + nonlocal cleaned_up + if cleaned_up: + return + cleaned_up = True + transport.remove_event_listener("message", on_message) + transport.remove_event_listener("sessionStatus", on_session_status) + + def close_readable(): + if not res_readable.is_closed(): + try: + res_readable._trigger_close() + except RuntimeError: + pass + if req_writable.is_closed(): + cleanup() + + # Create writable for requests + def write_cb(raw_value: Any) -> None: + try: + send_fn( + PartialTransportMessage( + payload=raw_value, + stream_id=stream_id, + control_flags=0, + ) + ) + 
except RuntimeError: + pass + + def close_cb() -> None: + nonlocal clean_close + if not proc_closes_with_init and clean_close: + try: + send_fn(close_stream_message(stream_id)) + except RuntimeError: + pass + if res_readable.is_closed(): + cleanup() + + req_writable: Writable = Writable(write_cb=write_cb, close_cb=close_cb) + + def on_message(msg: TransportMessage) -> None: + nonlocal clean_close + if msg.stream_id != stream_id: + return + if msg.to != transport.client_id: + return + + # Cancel from server + if is_stream_cancel(msg.control_flags): + clean_close = False + payload = msg.payload + if isinstance(payload, dict) and "ok" in payload: + res_readable._push_value(payload) + else: + res_readable._push_value( + err_result( + payload.get("code", "UNKNOWN") if isinstance(payload, dict) else "UNKNOWN", + str(payload), + ) + ) + close_readable() + if req_writable.is_writable(): + req_writable._closed = True + return + + if res_readable.is_closed(): + return + + # Normal payload (not a CLOSE control) + if isinstance(msg.payload, dict): + if msg.payload.get("type") != "CLOSE": + if "ok" in msg.payload: + res_readable._push_value(msg.payload) + + # Stream close + if is_stream_close(msg.control_flags): + close_readable() + + def on_session_status(evt: dict) -> None: + nonlocal clean_close + if evt.get("status") != "closing": + return + event_session = evt.get("session") + if event_session is None: + return + if event_session.to_id != to or event_session.id != session_id: + return + + clean_close = False + try: + res_readable._push_value( + err_result( + UNEXPECTED_DISCONNECT_CODE, + f"{to} unexpectedly disconnected", + ) + ) + except RuntimeError: + pass + close_readable() + if req_writable.is_writable(): + req_writable._closed = True + + def on_client_cancel() -> None: + nonlocal clean_close + clean_close = False + try: + res_readable._push_value( + err_result(CANCEL_CODE, "cancelled by client") + ) + except RuntimeError: + pass + close_readable() + if 
req_writable.is_writable(): + req_writable._closed = True + try: + send_fn( + cancel_message( + stream_id, + err_result(CANCEL_CODE, "cancelled by client"), + ) + ) + except RuntimeError: + pass + + # Register listeners + transport.add_event_listener("message", on_message) + transport.add_event_listener("sessionStatus", on_session_status) + + # Wire up abort signal + if abort_signal is not None: + # Use asyncio task to watch the event + async def _watch_abort(): + await abort_signal.wait() + on_client_cancel() + + try: + loop = asyncio.get_event_loop() + loop.create_task(_watch_abort()) + except RuntimeError: + pass + + # Send init message + init_flags = ( + ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit + if proc_closes_with_init + else ControlFlags.StreamOpenBit + ) + + try: + send_fn( + PartialTransportMessage( + payload=init, + stream_id=stream_id, + control_flags=init_flags, + service_name=service_name, + procedure_name=procedure_name, + ) + ) + except RuntimeError as e: + # Session dead at send time + try: + res_readable._push_value( + err_result( + UNEXPECTED_DISCONNECT_CODE, + f"{to} unexpectedly disconnected", + ) + ) + res_readable._trigger_close() + except RuntimeError: + pass + req_writable._closed = True + cleanup() + return { + "res_readable": res_readable, + "req_writable": req_writable, + } + + # For rpc/subscription, close request side immediately + if proc_closes_with_init: + req_writable._closed = True + + return { + "res_readable": res_readable, + "req_writable": req_writable, + } diff --git a/python-client/river/codec.py b/python-client/river/codec.py new file mode 100644 index 00000000..3d5e8f11 --- /dev/null +++ b/python-client/river/codec.py @@ -0,0 +1,110 @@ +"""Codec layer for encoding/decoding transport messages.""" + +from __future__ import annotations + +import json +import base64 +from abc import ABC, abstractmethod +from typing import Any + +from river.types import TransportMessage + + +class Codec(ABC): + """Abstract codec 
class Codec(ABC):
    """Abstract codec for encoding/decoding objects to/from bytes."""

    @abstractmethod
    def to_buffer(self, obj: dict[str, Any]) -> bytes:
        """Encode an object to bytes."""
        ...

    @abstractmethod
    def from_buffer(self, buf: bytes) -> dict[str, Any]:
        """Decode bytes to an object."""
        ...


class _CustomEncoder(json.JSONEncoder):
    """JSON encoder that serializes bytes as ``{"$t": "<base64>"}``.

    NOTE(review): the matching decoder also understands ``{"$b": "<int>"}``
    for big integers, but this encoder never emits it (Python ints are
    JSON-native); the previous docstring over-claimed large-int support.
    Decode-only support is kept for wire compatibility with the peer.
    """

    def default(self, o: Any) -> Any:
        # Wrap binary payloads so they survive JSON transport.
        if isinstance(o, (bytes, bytearray)):
            return {"$t": base64.b64encode(o).decode("ascii")}
        return super().default(o)


def _custom_object_hook(obj: dict) -> Any:
    """JSON object hook reversing the custom encodings.

    Only single-key dicts are transformed (``{"$t": ...}`` -> bytes,
    ``{"$b": ...}`` -> int), so user payloads that merely contain those
    keys alongside others pass through untouched.
    """
    if len(obj) == 1:
        if "$t" in obj:
            return base64.b64decode(obj["$t"])
        if "$b" in obj:
            return int(obj["$b"])
    return obj


class NaiveJsonCodec(Codec):
    """Codec using JSON serialization (matches TypeScript NaiveJsonCodec)."""

    name = "naive"

    def to_buffer(self, obj: dict[str, Any]) -> bytes:
        # Compact separators keep the wire format byte-compatible with
        # the TypeScript implementation.
        return json.dumps(
            obj, cls=_CustomEncoder, separators=(",", ":")
        ).encode("utf-8")

    def from_buffer(self, buf: bytes) -> dict[str, Any]:
        return json.loads(buf.decode("utf-8"), object_hook=_custom_object_hook)


class BinaryCodec(Codec):
    """Codec using msgpack serialization (matches TypeScript BinaryCodec)."""

    name = "binary"

    def to_buffer(self, obj: dict[str, Any]) -> bytes:
        # Imported lazily so the JSON codec is usable without msgpack.
        import msgpack  # type: ignore[import-untyped]

        return msgpack.packb(obj, use_bin_type=True)

    def from_buffer(self, buf: bytes) -> dict[str, Any]:
        import msgpack  # type: ignore[import-untyped]

        return msgpack.unpackb(buf, raw=False)


class CodecMessageAdapter:
    """Wraps a Codec with error handling and validation for TransportMessage."""

    # Fields every incoming wire message must carry before we attempt
    # to build a TransportMessage from it.
    _REQUIRED_FIELDS = ("id", "from", "to", "seq", "ack", "payload", "streamId")

    def __init__(self, codec: Codec) -> None:
        self._codec = codec

    def to_buffer(self, msg: TransportMessage) -> tuple[bool, bytes | str]:
        """Serialize a TransportMessage to bytes.

        Returns (True, bytes) on success, (False, error_reason) on failure.
        """
        try:
            raw = msg.to_dict()
            return True, self._codec.to_buffer(raw)
        except Exception as e:
            return False, f"Failed to serialize message: {e}"

    def from_buffer(self, buf: bytes) -> tuple[bool, TransportMessage | str]:
        """Deserialize bytes to a TransportMessage.

        Returns (True, TransportMessage) on success, (False, error_reason)
        on failure (codec error, non-dict payload, or missing fields).
        """
        try:
            raw = self._codec.from_buffer(buf)
            if not isinstance(raw, dict):
                return False, f"Expected dict, got {type(raw).__name__}"
            for field in self._REQUIRED_FIELDS:
                if field not in raw:
                    return False, f"Missing required field: {field}"
            return True, TransportMessage.from_dict(raw)
        except Exception as e:
            return False, f"Failed to deserialize message: {e}"
+""" + +from __future__ import annotations + +import asyncio +import logging +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable + +from river.codec import CodecMessageAdapter +from river.types import ( + ControlFlags, + PartialTransportMessage, + TransportMessage, + generate_id, + handshake_request_payload, + heartbeat_message, + is_ack, + PROTOCOL_VERSION, +) + +logger = logging.getLogger(__name__) + + +class SessionState(str, Enum): + """Session state machine states.""" + + NO_CONNECTION = "NoConnection" + BACKING_OFF = "BackingOff" + CONNECTING = "Connecting" + HANDSHAKING = "Handshaking" + CONNECTED = "Connected" + + +@dataclass +class SessionOptions: + """Configuration options for a session.""" + + heartbeat_interval_ms: float = 1000 + heartbeats_until_dead: int = 2 + session_disconnect_grace_ms: float = 5000 + connection_timeout_ms: float = 2000 + handshake_timeout_ms: float = 1000 + enable_transparent_reconnects: bool = True + + +DEFAULT_SESSION_OPTIONS = SessionOptions() + + +class Session: + """Represents a River session with seq/ack bookkeeping and send buffer. + + A session persists across potentially multiple connections, tracking + all the state needed for transparent reconnection. 
+ """ + + def __init__( + self, + session_id: str, + from_id: str, + to_id: str, + codec: CodecMessageAdapter, + options: SessionOptions | None = None, + ) -> None: + self.id = session_id + self.from_id = from_id + self.to_id = to_id + self.codec = codec + self.options = options or DEFAULT_SESSION_OPTIONS + + # Seq/ack bookkeeping + self.seq: int = 0 # Next seq to assign when sending + self.ack: int = 0 # Next expected seq from the other side + self.send_buffer: list[TransportMessage] = [] + + # State machine + self.state: SessionState = SessionState.NO_CONNECTION + + # Connection + self._ws: Any = None # The WebSocket connection + self._is_actively_heartbeating: bool = False + + # Timers + self._heartbeat_task: asyncio.Task | None = None + self._heartbeat_miss_task: asyncio.Task | None = None + self._grace_period_task: asyncio.Task | None = None + self._grace_expiry_time: float | None = None + + # Callbacks + self._on_message: Callable[[TransportMessage], None] | None = None + self._on_connection_closed: Callable[[], None] | None = None + self._on_session_grace_elapsed: Callable[[], None] | None = None + + self._destroyed = False + + @property + def next_seq(self) -> int: + """The next seq the other side should see from us. + + Returns the seq of the first unacked message in the buffer, + or our current seq if the buffer is empty. + """ + if self.send_buffer: + return self.send_buffer[0].seq + return self.seq + + def construct_msg( + self, partial: PartialTransportMessage + ) -> TransportMessage: + """Construct a full TransportMessage from a partial one. + + Fills in id, from, to, seq, ack and increments seq. 
+ """ + msg = TransportMessage( + id=generate_id(), + from_=self.from_id, + to=self.to_id, + seq=self.seq, + ack=self.ack, + payload=partial.payload, + stream_id=partial.stream_id, + control_flags=partial.control_flags, + service_name=partial.service_name, + procedure_name=partial.procedure_name, + tracing=partial.tracing, + ) + self.seq += 1 + return msg + + def send(self, partial: PartialTransportMessage) -> tuple[bool, str]: + """Construct and send a message. + + When connected, sends immediately over the wire and buffers. + When disconnected, only buffers. + + Returns (True, msg_id) on success, (False, reason) on failure. + """ + msg = self.construct_msg(partial) + self.send_buffer.append(msg) + + if self.state == SessionState.CONNECTED and self._ws is not None: + ok, result = self._send_over_wire(msg) + if not ok: + return False, result + return True, msg.id + + def _send_over_wire(self, msg: TransportMessage) -> tuple[bool, str]: + """Serialize and send a message over the current connection.""" + ok, result = self.codec.to_buffer(msg) + if not ok: + return False, result # type: ignore[return-value] + try: + assert self._ws is not None + # websockets library uses async send, but we schedule it + asyncio.get_event_loop().call_soon( + lambda data=result: self._do_ws_send(data) + ) + return True, msg.id + except Exception as e: + return False, f"Failed to send: {e}" + + def _do_ws_send(self, data: bytes) -> None: + """Actually send data over the WebSocket.""" + if self._ws is not None and not self._destroyed: + try: + asyncio.ensure_future(self._ws.send(data)) + except Exception as e: + logger.error("WebSocket send error: %s", e) + + def send_buffered_messages(self) -> tuple[bool, str | None]: + """Retransmit all buffered messages over the current connection. + + Called after a successful reconnection handshake. 
+ """ + for msg in self.send_buffer: + ok, reason = self._send_over_wire(msg) + if not ok: + return False, reason + return True, None + + def update_bookkeeping(self, their_ack: int, their_seq: int) -> None: + """Update seq/ack bookkeeping based on an incoming message. + + - Removes acknowledged messages from the send buffer. + - Updates our ack to their_seq + 1. + - Resets the heartbeat miss timeout. + """ + # Remove acked messages from send buffer + self.send_buffer = [m for m in self.send_buffer if m.seq >= their_ack] + # Update our ack + self.ack = their_seq + 1 + # Reset heartbeat miss timer + self._reset_heartbeat_miss_timeout() + + def send_heartbeat(self) -> None: + """Send a heartbeat message.""" + self.send(heartbeat_message()) + + def start_active_heartbeat(self, loop: asyncio.AbstractEventLoop) -> None: + """Start sending heartbeats at the configured interval (server behavior).""" + self._is_actively_heartbeating = True + interval = self.options.heartbeat_interval_ms / 1000.0 + + async def _heartbeat_loop(): + try: + while not self._destroyed and self.state == SessionState.CONNECTED: + await asyncio.sleep(interval) + if not self._destroyed and self.state == SessionState.CONNECTED: + self.send_heartbeat() + except asyncio.CancelledError: + pass + + self._heartbeat_task = loop.create_task(_heartbeat_loop()) + + def start_heartbeat_miss_timeout(self, loop: asyncio.AbstractEventLoop) -> None: + """Start the missing heartbeat timeout.""" + miss_duration = ( + self.options.heartbeats_until_dead + * self.options.heartbeat_interval_ms + / 1000.0 + ) + + async def _miss_timeout(): + try: + await asyncio.sleep(miss_duration) + if not self._destroyed and self._on_connection_closed: + logger.debug( + "Session %s: heartbeat miss timeout, closing connection", + self.id, + ) + self._on_connection_closed() + except asyncio.CancelledError: + pass + + if self._heartbeat_miss_task: + self._heartbeat_miss_task.cancel() + self._heartbeat_miss_task = 
loop.create_task(_miss_timeout()) + + def _reset_heartbeat_miss_timeout(self) -> None: + """Reset the heartbeat miss timer.""" + if self._heartbeat_miss_task: + self._heartbeat_miss_task.cancel() + self._heartbeat_miss_task = None + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + self.start_heartbeat_miss_timeout(loop) + except RuntimeError: + pass + + def start_grace_period(self, loop: asyncio.AbstractEventLoop) -> None: + """Start the session disconnect grace period. + + If the session is not reconnected within this time, it's destroyed. + """ + grace_ms = self.options.session_disconnect_grace_ms + self._grace_expiry_time = time.monotonic() + grace_ms / 1000.0 + + async def _grace_timeout(): + try: + await asyncio.sleep(grace_ms / 1000.0) + if not self._destroyed and self._on_session_grace_elapsed: + logger.debug( + "Session %s: grace period elapsed, destroying", self.id + ) + self._on_session_grace_elapsed() + except asyncio.CancelledError: + pass + + if self._grace_period_task: + self._grace_period_task.cancel() + self._grace_period_task = loop.create_task(_grace_timeout()) + + def cancel_grace_period(self) -> None: + """Cancel the session disconnect grace period.""" + if self._grace_period_task: + self._grace_period_task.cancel() + self._grace_period_task = None + self._grace_expiry_time = None + + def cancel_heartbeats(self) -> None: + """Cancel all heartbeat-related tasks.""" + if self._heartbeat_task: + self._heartbeat_task.cancel() + self._heartbeat_task = None + if self._heartbeat_miss_task: + self._heartbeat_miss_task.cancel() + self._heartbeat_miss_task = None + self._is_actively_heartbeating = False + + def set_connected(self, ws: Any, loop: asyncio.AbstractEventLoop) -> None: + """Transition to connected state.""" + self.state = SessionState.CONNECTED + self._ws = ws + self.cancel_grace_period() + self.start_heartbeat_miss_timeout(loop) + + def set_disconnected(self, loop: asyncio.AbstractEventLoop) -> None: + """Transition to 
disconnected state (no connection).""" + self.state = SessionState.NO_CONNECTION + self.cancel_heartbeats() + old_ws = self._ws + self._ws = None + if old_ws is not None: + try: + asyncio.ensure_future(old_ws.close()) + except Exception: + pass + self.start_grace_period(loop) + + def destroy(self) -> None: + """Destroy the session, cleaning up all resources.""" + self._destroyed = True + self.cancel_heartbeats() + self.cancel_grace_period() + if self._ws is not None: + try: + asyncio.ensure_future(self._ws.close()) + except Exception: + pass + self._ws = None + self.send_buffer.clear() + + def create_handshake_request( + self, metadata: Any = None + ) -> TransportMessage: + """Create a handshake request transport message. + + Handshake messages have seq=0, ack=0, controlFlags=0. + """ + payload = handshake_request_payload( + session_id=self.id, + next_expected_seq=self.ack, + next_sent_seq=self.next_seq, + metadata=metadata, + ) + return TransportMessage( + id=generate_id(), + from_=self.from_id, + to=self.to_id, + seq=0, + ack=0, + payload=payload, + stream_id="handshake", + control_flags=0, + ) diff --git a/python-client/river/streams.py b/python-client/river/streams.py new file mode 100644 index 00000000..cbca9601 --- /dev/null +++ b/python-client/river/streams.py @@ -0,0 +1,215 @@ +"""Readable and Writable stream abstractions for River procedures.""" + +from __future__ import annotations + +import asyncio +from typing import Any, Callable, Generic, TypeVar + +T = TypeVar("T") + + +class ReadableBrokenError(Exception): + """Raised when a readable stream is broken.""" + + pass + + +class Readable(Generic[T]): + """Async readable stream for consuming procedure results. + + Supports async iteration via `async for` and explicit read via `next()`. 
T = TypeVar("T")


class Readable(Generic[T]):
    """Async readable stream for consuming procedure results.

    Supports async iteration via ``async for`` and explicit reads via
    ``next()``. Values are buffered in an internal queue until consumed;
    pending readers park on futures resolved as data arrives.
    """

    def __init__(self) -> None:
        self._queue: list[T] = []  # buffered, not-yet-consumed values
        self._closed = False       # producer finished
        self._broken = False       # consumer gave up mid-stream
        self._locked = False       # a consumer has claimed the stream
        self._waiters: list[asyncio.Future[None]] = []

    def _push_value(self, value: T) -> None:
        """Push a value into the readable stream (internal use)."""
        if self._closed:
            raise RuntimeError("Cannot push to a closed readable")
        self._queue.append(value)
        self._notify_waiters()

    def _trigger_close(self) -> None:
        """Close the readable stream (internal use)."""
        if self._closed:
            raise RuntimeError("Readable already closed")
        self._closed = True
        self._notify_waiters()

    def _notify_waiters(self) -> None:
        # Wake every parked reader; each re-checks state when it resumes.
        while self._waiters:
            waiter = self._waiters.pop(0)
            if not waiter.done():
                waiter.set_result(None)

    def is_readable(self) -> bool:
        """Whether the stream can still be iterated (not locked or broken)."""
        return not self._locked and not self._broken

    def is_closed(self) -> bool:
        """Whether the stream has been closed and fully drained."""
        return self._closed and not self._queue

    def _has_values_in_queue(self) -> bool:
        """Whether there are buffered values waiting to be consumed."""
        return bool(self._queue)

    def break_(self) -> None:
        """Break the stream, discarding all queued values.

        If the stream is already closed and drained this is a no-op
        (the stream is already done).
        """
        if self._locked and self._broken:
            return
        self._locked = True
        # If stream is already done (closed + empty), don't signal broken
        if self._closed and not self._queue:
            self._notify_waiters()
            return
        self._broken = True
        self._queue.clear()
        self._notify_waiters()

    async def _wait_for_data(self) -> None:
        # Park until a producer pushes, closes, or breaks the stream.
        # get_running_loop() replaces the deprecated get_event_loop();
        # we are always inside a coroutine here.
        fut: asyncio.Future[None] = asyncio.get_running_loop().create_future()
        self._waiters.append(fut)
        await fut

    @staticmethod
    def _broken_result() -> dict:
        # Error payload surfaced to consumers of a broken stream.
        return {
            "ok": False,
            "payload": {
                "code": "READABLE_BROKEN",
                "message": "stream was broken",
            },
        }

    async def collect(self) -> list[T]:
        """Consume all values from the stream until it closes.

        Locks the stream. Raises TypeError if already locked.
        Returns the list of all values.
        """
        if self._locked:
            raise TypeError("Readable is already locked")
        self._locked = True
        return [item async for item in self._iterate()]

    async def next(self) -> tuple[bool, T | None]:
        """Read the next value from the stream.

        Returns (False, value) if a value is available.
        Returns (True, None) if the stream is done.

        Implemented without an async generator: the previous version
        abandoned a suspended `_iterate` generator on every call, which
        was never closed until GC.
        """
        self._locked = True  # matches the locking done by iteration
        while True:
            if self._broken:
                return False, self._broken_result()  # type: ignore[return-value]
            if self._queue:
                return False, self._queue.pop(0)
            if self._closed:
                return True, None
            await self._wait_for_data()

    async def _iterate(self):
        """Internal async generator backing collect()."""
        self._locked = True
        while True:
            if self._broken:
                # After yielding the broken error, the stream is done.
                yield self._broken_result()
                return
            if self._queue:
                yield self._queue.pop(0)
                continue
            if self._closed:
                return
            await self._wait_for_data()

    def __aiter__(self):
        if self._locked:
            raise TypeError("Readable is already locked")
        self._locked = True
        return _ReadableIterator(self)


class _ReadableIterator:
    """Async iterator for Readable that cleans up on break/close.

    Unlike an async generator, ``__del__`` runs synchronously here,
    ensuring the queue is cleared when a for-await loop breaks out.
    """

    def __init__(self, readable: Readable) -> None:
        self._readable = readable
        self._done = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._done:
            raise StopAsyncIteration

        readable = self._readable
        while True:
            if readable._broken:
                # Surface one broken-stream error, then finish.
                self._done = True
                return Readable._broken_result()
            if readable._queue:
                return readable._queue.pop(0)
            if readable._closed:
                raise StopAsyncIteration
            await readable._wait_for_data()

    def __del__(self):
        # Synchronous cleanup when the iterator is GC'd (e.g. break in for-await)
        self._readable._queue.clear()


class Writable(Generic[T]):
    """Writable stream for sending procedure requests.

    Wraps a write callback and an optional close callback.
    """

    def __init__(
        self,
        write_cb: Callable[[T], None],
        close_cb: Callable[[], None] | None = None,
    ) -> None:
        self._write_cb = write_cb
        self._close_cb = close_cb
        self._closed = False

    def write(self, value: T) -> None:
        """Write a value to the stream; raises if already closed."""
        if self._closed:
            raise RuntimeError("Cannot write to a closed writable")
        self._write_cb(value)

    def close(self, value: T | None = None) -> None:
        """Close the stream, optionally writing a final value. Idempotent."""
        if self._closed:
            return
        self._closed = True
        if value is not None:
            # Final value bypasses write() since we just marked closed.
            self._write_cb(value)
        if self._close_cb:
            self._close_cb()

    def is_writable(self) -> bool:
        return not self._closed

    def is_closed(self) -> bool:
        return self._closed
logger = logging.getLogger(__name__)


class EventDispatcher:
    """Simple event dispatcher keyed by event name.

    Handlers are stored in a set per event; exceptions raised by a
    handler are logged and do not prevent later handlers from running.
    """

    def __init__(self) -> None:
        self._handlers: dict[str, set[Callable]] = {}

    def add_listener(self, event: str, handler: Callable) -> None:
        """Register a handler for an event name."""
        if event not in self._handlers:
            self._handlers[event] = set()
        self._handlers[event].add(handler)

    def remove_listener(self, event: str, handler: Callable) -> None:
        """Unregister a handler; a no-op if it was never registered."""
        if event in self._handlers:
            self._handlers[event].discard(handler)

    def dispatch(self, event: str, data: Any = None) -> None:
        """Invoke every handler registered for the event with `data`."""
        if event in self._handlers:
            # Copy to avoid mutation during iteration
            for handler in list(self._handlers[event]):
                try:
                    handler(data)
                except Exception as e:
                    logger.error("Event handler error for %s: %s", event, e)

    def listener_count(self, event: str) -> int:
        """Number of handlers currently registered for the event."""
        return len(self._handlers.get(event, set()))


class LeakyBucketRateLimit:
    """Rate limiter with exponential backoff for connection retries.

    Each failed attempt consumes budget; backoff grows exponentially
    with budget consumed (plus random jitter) and is restored gradually
    after a successful connection.
    """

    def __init__(
        self,
        base_interval_ms: float = 150,
        max_jitter_ms: float = 200,
        max_backoff_ms: float = 32_000,
        attempt_budget_capacity: int = 5,
        budget_restore_interval_ms: float = 200,
    ) -> None:
        self.base_interval_ms = base_interval_ms
        self.max_jitter_ms = max_jitter_ms
        self.max_backoff_ms = max_backoff_ms
        self.attempt_budget_capacity = attempt_budget_capacity
        self.budget_restore_interval_ms = budget_restore_interval_ms
        self.budget_consumed: int = 0
        self._restore_task: asyncio.Task | None = None

    def has_budget(self) -> bool:
        """Whether another connection attempt is currently allowed."""
        return self.budget_consumed < self.attempt_budget_capacity

    def get_backoff_ms(self) -> float:
        """Backoff before the next attempt: base * 2^(consumed-1) + jitter,
        capped at max_backoff_ms (jitter added after the cap)."""
        if self.budget_consumed == 0:
            return 0
        exponent = max(0, self.budget_consumed - 1)
        jitter = random.random() * self.max_jitter_ms
        backoff = min(
            self.base_interval_ms * (2**exponent), self.max_backoff_ms
        )
        return backoff + jitter

    def consume_budget(self) -> None:
        """Record a connection attempt; pauses any in-flight restoration."""
        self._stop_restore()
        self.budget_consumed += 1

    def start_restoring_budget(self) -> None:
        """Start gradually restoring budget after a successful connection."""
        self._stop_restore()

        async def _restore_loop():
            try:
                while self.budget_consumed > 0:
                    await asyncio.sleep(
                        self.budget_restore_interval_ms / 1000.0
                    )
                    self.budget_consumed = max(0, self.budget_consumed - 1)
            except asyncio.CancelledError:
                pass

        try:
            # get_running_loop() replaces the deprecated get_event_loop();
            # with no running loop, restoration simply does not start,
            # matching the previous except-RuntimeError behavior.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        self._restore_task = loop.create_task(_restore_loop())

    def _stop_restore(self) -> None:
        if self._restore_task:
            self._restore_task.cancel()
            self._restore_task = None

    def reset(self) -> None:
        """Clear all consumed budget and stop restoration."""
        self.budget_consumed = 0
        self._stop_restore()
    """

    def __init__(
        self,
        ws_url: str | Callable[..., str],
        client_id: str | None = None,
        server_id: str | None = None,
        codec: Codec | None = None,
        options: SessionOptions | None = None,
        handshake_metadata: Any = None,
        connect_on_invoke: bool = True,
        eagerly_connect: bool = False,
    ) -> None:
        # NOTE(review): eagerly_connect is accepted but never stored or
        # read anywhere in this class — confirm whether it is intended to
        # trigger connect() here or is consumed by a caller.
        self.client_id = client_id or generate_id()
        self.server_id = server_id or "SERVER"
        self._ws_url = ws_url
        self._codec = codec or NaiveJsonCodec()
        self._codec_adapter = CodecMessageAdapter(self._codec)
        self.options = options or DEFAULT_SESSION_OPTIONS
        self._handshake_metadata = handshake_metadata
        self._connect_on_invoke = connect_on_invoke

        # State
        self._status: str = "open"  # 'open' | 'closed'
        self.sessions: dict[str, Session] = {}  # to_id -> Session
        self._events = EventDispatcher()
        self._retry_budget = LeakyBucketRateLimit()
        self._reconnect_on_connection_drop = True

        # Connection tasks (one in-flight connect task per destination id)
        self._connect_tasks: dict[str, asyncio.Task] = {}

        # Event loop is resolved lazily on first use, see _get_loop().
        self._loop: asyncio.AbstractEventLoop | None = None

    def get_status(self) -> str:
        """Return the transport status: 'open' or 'closed'."""
        return self._status

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        # Lazily capture and cache the event loop used for all scheduling.
        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop
        # is running — consider get_running_loop(); confirm call sites
        # always run inside a loop.
        if self._loop is None:
            self._loop = asyncio.get_event_loop()
        return self._loop

    # --- Event API ---

    def add_event_listener(self, event: str, handler: Callable) -> None:
        """Register a handler for a transport event (e.g. 'message')."""
        self._events.add_listener(event, handler)

    def remove_event_listener(self, event: str, handler: Callable) -> None:
        """Unregister a previously added event handler."""
        self._events.remove_listener(event, handler)

    # --- Session Management ---

    def _get_or_create_session(self, to: str) -> Session:
        """Get an existing session or create a new unconnected one."""
        if to in self.sessions:
            return self.sessions[to]
        session = Session(
            session_id=generate_id(),
            from_id=self.client_id,
            to_id=to,
            codec=self._codec_adapter,
            options=self.options,
        )
        # When the session's grace period elapses without a reconnect, the
        # session is torn down (see _on_session_grace_elapsed).
        session._on_session_grace_elapsed = lambda: self._on_session_grace_elapsed(to)
        self.sessions[to] = session
        self._events.dispatch(
            "sessionStatus", {"status": "created", "session": session}
        )
        return session

    def _delete_session(self, to: str, emit_closing: bool = True) -> None:
        """Delete a session and clean up.

        Emits 'closing' (optional) and 'closed' sessionStatus events around
        the destroy.
        """
        session = self.sessions.pop(to, None)
        if session is None:
            return
        if emit_closing:
            self._events.dispatch(
                "sessionStatus", {"status": "closing", "session": session}
            )
        session.destroy()
        self._events.dispatch(
            "sessionStatus", {"status": "closed", "session": session}
        )

    def _on_session_grace_elapsed(self, to: str) -> None:
        """Called when a session's grace period expires."""
        logger.debug("Session grace period elapsed for %s", to)
        self._delete_session(to)

    # --- Connection Flow ---

    def connect(self, to: str | None = None) -> None:
        """Initiate a connection to the given server.

        Follows the state transition:
        NoConnection -> BackingOff -> Connecting -> Handshaking -> Connected
        """
        to = to or self.server_id
        if self._status != "open":
            return

        session = self._get_or_create_session(to)
        if session.state != SessionState.NO_CONNECTION:
            return  # Already connecting/connected

        # Out of retry budget: surface a protocol error instead of looping.
        if not self._retry_budget.has_budget():
            self._events.dispatch(
                "protocolError",
                {"type": "conn_retry_exceeded", "message": "Retries exceeded"},
            )
            return

        backoff_ms = self._retry_budget.get_backoff_ms()
        self._retry_budget.consume_budget()

        # Schedule the connection attempt after backoff
        loop = self._get_loop()
        session.state = SessionState.BACKING_OFF

        async def _do_connect():
            try:
                if backoff_ms > 0:
                    await asyncio.sleep(backoff_ms / 1000.0)

                # Transport may have closed or the session may have been
                # destroyed while we were backing off.
                if self._status != "open" or session._destroyed:
                    return

                session.state = SessionState.CONNECTING
                ws = await self._create_connection(to)

                if session._destroyed:
                    await ws.close()
                    return

                session.state = SessionState.HANDSHAKING
                await self._do_handshake(session, ws, to)
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.debug("Connection attempt failed for %s: %s", to, e)
                if not session._destroyed:
                    self._on_connection_failed(to)

        task = loop.create_task(_do_connect())
        self._connect_tasks[to] = task

    async def _create_connection(self, to: str) -> Any:
        """Create a new WebSocket connection.

        Library pings are disabled (ping_interval/ping_timeout None)
        because River does its own heartbeating at the session layer.
        """
        import websockets  # type: ignore[import-untyped]

        url = self._ws_url if isinstance(self._ws_url, str) else self._ws_url(to)

        ws = await asyncio.wait_for(
            websockets.connect(url, max_size=None, ping_interval=None, ping_timeout=None),
            timeout=self.options.connection_timeout_ms / 1000.0,
        )
        return ws

    async def _do_handshake(
        self, session: Session, ws: Any, to: str
    ) -> None:
        """Perform the handshake on a newly connected WebSocket.

        On success, marks the session connected, flushes its send buffer,
        and starts the message listener. On failure, the socket is closed
        and either a retry is scheduled or a protocolError is emitted,
        depending on the failure code.
        """
        # Send handshake request
        hs_msg = session.create_handshake_request(
            metadata=self._handshake_metadata
        )
        ok, buf = self._codec_adapter.to_buffer(hs_msg)
        if not ok:
            logger.error("Failed to encode handshake: %s", buf)
            await ws.close()
            self._on_connection_failed(to)
            return

        await ws.send(buf)

        # Wait for handshake response
        try:
            response_bytes = await asyncio.wait_for(
                ws.recv(), timeout=self.options.handshake_timeout_ms / 1000.0
            )
        # NOTE(review): (asyncio.TimeoutError, Exception) is redundant —
        # Exception already subsumes TimeoutError; presumably the intent
        # is "catch everything here". Confirm and simplify.
        except (asyncio.TimeoutError, Exception) as e:
            logger.debug("Handshake timeout/error for %s: %s", to, e)
            await ws.close()
            self._on_connection_failed(to)
            return

        # Text frames arrive as str; codec expects bytes.
        if isinstance(response_bytes, str):
            response_bytes = response_bytes.encode("utf-8")

        ok, result = self._codec_adapter.from_buffer(response_bytes)
        if not ok:
            logger.error("Failed to decode handshake response: %s", result)
            await ws.close()
            self._on_connection_failed(to)
            return

        response_msg: TransportMessage = result  # type: ignore[assignment]
        payload = response_msg.payload

        # Validate handshake response
        if (
            not isinstance(payload, dict)
            or payload.get("type") != "HANDSHAKE_RESP"
        ):
            logger.error("Invalid handshake response payload")
            await ws.close()
            self._on_connection_failed(to)
            return

        status = payload.get("status", {})
        if not status.get("ok"):
            code = status.get("code", "UNKNOWN")
            reason = status.get("reason", "Unknown reason")
            logger.debug(
                "Handshake rejected for %s: %s (%s)", to, reason, code
            )
            await ws.close()

            if code in RETRIABLE_HANDSHAKE_CODES:
                # Session state mismatch - destroy session and retry
                self._delete_session(to)
                self._try_reconnecting(to)
            else:
                # Fatal rejection: report and let the normal failure path
                # decide whether to reconnect.
                self._events.dispatch(
                    "protocolError",
                    {
                        "type": "handshake_failed",
                        "message": reason,
                        "code": code,
                    },
                )
                self._on_connection_failed(to)
            return

        # Check session ID match
        resp_session_id = status.get("sessionId")
        if resp_session_id != session.id:
            # Server assigned a different session - old session is stale
            logger.debug(
                "Session ID mismatch: expected %s, got %s",
                session.id,
                resp_session_id,
            )
            # The server lost our session state; destroy old and create new
            self._delete_session(to, emit_closing=True)
            self._try_reconnecting(to)
            return

        # Handshake successful
        loop = self._get_loop()
        session.set_connected(ws, loop)
        self._events.dispatch(
            "sessionTransition",
            {"state": SessionState.CONNECTED, "id": session.id},
        )

        # Retransmit buffered messages
        ok, err = session.send_buffered_messages()
        if not ok:
            logger.error("Failed to send buffered messages: %s", err)
            self._events.dispatch(
                "protocolError",
                {"type": "message_send_failure", "message": err},
            )
            self._delete_session(to)
            return

        # Start restoring retry budget
        self._retry_budget.start_restoring_budget()

        # Start listening for messages
        self._start_message_listener(session, ws, to)

    def _start_message_listener(
        self, session: Session, ws: Any, to: str
    ) -> None:
        """Start the async message listener on the WebSocket."""
        loop = self._get_loop()

        session._on_connection_closed = lambda: self._on_connection_dropped(to)

        async def _listen():
            try:
                async for raw_msg in ws:
                    if session._destroyed:
                        break
                    if isinstance(raw_msg, str):
                        raw_msg = raw_msg.encode("utf-8")
                    self._on_message_data(session, raw_msg, to)
            except Exception as e:
                if not session._destroyed:
                    logger.debug(
                        "WebSocket error for session %s: %s", session.id, e
                    )
            finally:
                # Both normal socket close and errors funnel into the
                # dropped-connection handler (unless the session was
                # deliberately destroyed).
                if not session._destroyed:
                    self._on_connection_dropped(to)

        loop.create_task(_listen())

    def _on_message_data(
        self, session: Session, raw: bytes, to: str
    ) -> None:
        """Handle raw bytes received from the WebSocket.

        Enforces strict seq ordering: duplicates are dropped, gaps force a
        reconnect so the server can retransmit from its buffer.
        """
        ok, result = self._codec_adapter.from_buffer(raw)
        if not ok:
            self._events.dispatch(
                "protocolError",
                {"type": "invalid_message", "message": result},
            )
            return

        msg: TransportMessage = result  # type: ignore[assignment]

        # Check message ordering
        if msg.seq != session.ack:
            if msg.seq < session.ack:
                # Duplicate - discard silently
                return
            else:
                # Future message - close connection to force re-handshake
                logger.debug(
                    "Seq out of order: expected %d, got %d. Closing.",
                    session.ack,
                    msg.seq,
                )
                if session._ws:
                    asyncio.ensure_future(session._ws.close())
                return

        # Update bookkeeping
        session.update_bookkeeping(msg.ack, msg.seq)

        # Dispatch non-heartbeat messages
        if not is_ack(msg.control_flags):
            self._events.dispatch("message", msg)
            return

        # If this is a heartbeat and we're not actively heartbeating (client),
        # echo back
        if not session._is_actively_heartbeating:
            session.send_heartbeat()

    def _on_connection_dropped(self, to: str) -> None:
        """Handle a dropped connection.

        Only acts on sessions that were fully CONNECTED; half-open
        connection attempts are handled by _on_connection_failed.
        """
        session = self.sessions.get(to)
        if session is None or session._destroyed:
            return
        if session.state != SessionState.CONNECTED:
            return

        loop = self._get_loop()
        session.set_disconnected(loop)
        self._events.dispatch(
            "sessionTransition",
            {"state": SessionState.NO_CONNECTION, "id": session.id},
        )

        if self._reconnect_on_connection_drop:
            self._try_reconnecting(to)

    def _on_connection_failed(self, to: str) -> None:
        """Handle a failed connection attempt."""
        session = self.sessions.get(to)
        if session is None or session._destroyed:
            return

        # NOTE(review): loop is fetched but unused here — confirm whether
        # a set_disconnected(loop) call was intended, as in
        # _on_connection_dropped.
        loop = self._get_loop()
        session.state = SessionState.NO_CONNECTION

        if self._reconnect_on_connection_drop:
            self._try_reconnecting(to)

    def _try_reconnecting(self, to: str) -> None:
        """Try to reconnect to the server."""
        if self._status != "open":
            return
        if not self._reconnect_on_connection_drop:
            return
        # Use call_soon to break out of the current call stack
        loop = self._get_loop()
        loop.call_soon(lambda: self.connect(to))

    # --- Session-Bound Send ---

    def get_session_bound_send_fn(
        self, to: str, session_id: str
    ) -> Callable[[PartialTransportMessage], str]:
        """Get a send function scoped to a specific session.

        The send function will raise if the session has been replaced or destroyed.
        """

        def _send(msg: PartialTransportMessage) -> str:
            session = self.sessions.get(to)
            if session is None:
                raise RuntimeError("Session scope ended (closed)")
            # A different session id means the original session was
            # replaced (e.g. after a server-side state loss).
            if session.id != session_id or session._destroyed:
                raise RuntimeError("Session scope ended (transition)")

            ok, result = session.send(msg)
            if not ok:
                raise RuntimeError(f"Send failed: {result}")
            return result

        return _send

    # --- Lifecycle ---

    async def close(self) -> None:
        """Close the transport and all sessions.

        Idempotent: subsequent calls are no-ops once closed.
        """
        if self._status == "closed":
            return
        self._status = "closed"

        # Cancel all pending connection tasks
        # NOTE(review): tasks are cancelled but not awaited — their
        # CancelledError handling runs later on the loop; confirm that is
        # acceptable at shutdown.
        for task in self._connect_tasks.values():
            task.cancel()
        self._connect_tasks.clear()

        # Delete all sessions (copy keys: _delete_session mutates the dict)
        for to in list(self.sessions.keys()):
            self._delete_session(to)

        self._retry_budget.reset()
        self._events.dispatch("transportStatus", {"status": "closed"})

    @property
    def reconnect_on_connection_drop(self) -> bool:
        """Whether a dropped/failed connection schedules a reconnect."""
        return self._reconnect_on_connection_drop

    @reconnect_on_connection_drop.setter
    def reconnect_on_connection_drop(self, value: bool) -> None:
        self._reconnect_on_connection_drop = value
value diff --git a/python-client/river/types.py b/python-client/river/types.py new file mode 100644 index 00000000..591ed830 --- /dev/null +++ b/python-client/river/types.py @@ -0,0 +1,259 @@ +"""Core types for the River protocol.""" + +from __future__ import annotations + +import string +import random +from dataclasses import dataclass, field +from enum import IntFlag +from typing import Any, TypeVar, Generic, Union + + +# --- ID Generation --- + +_ID_ALPHABET = string.ascii_letters + string.digits +_ID_LENGTH = 12 + + +def generate_id() -> str: + """Generate a nanoid-style random ID (12 chars, alphanumeric).""" + return "".join(random.choices(_ID_ALPHABET, k=_ID_LENGTH)) + + +# --- Control Flags --- + + +class ControlFlags(IntFlag): + """Bit flags for transport message control signals.""" + + AckBit = 0b00001 # 1 - heartbeat/ack only + StreamOpenBit = 0b00010 # 2 - first message of a stream + StreamCancelBit = 0b00100 # 4 - abrupt cancel with ProtocolError payload + StreamClosedBit = 0b01000 # 8 - last message of a stream + + +def is_ack(flags: int) -> bool: + return (flags & ControlFlags.AckBit) == ControlFlags.AckBit + + +def is_stream_open(flags: int) -> bool: + return (flags & ControlFlags.StreamOpenBit) == ControlFlags.StreamOpenBit + + +def is_stream_cancel(flags: int) -> bool: + return (flags & ControlFlags.StreamCancelBit) == ControlFlags.StreamCancelBit + + +def is_stream_close(flags: int) -> bool: + return (flags & ControlFlags.StreamClosedBit) == ControlFlags.StreamClosedBit + + +# --- Transport Message --- + + +@dataclass +class TransportMessage: + """The envelope for all messages sent over the wire.""" + + id: str + from_: str # 'from' is a Python keyword + to: str + seq: int + ack: int + payload: Any + stream_id: str + control_flags: int = 0 + service_name: str | None = None + procedure_name: str | None = None + tracing: dict[str, str] | None = None + + def to_dict(self) -> dict[str, Any]: + """Serialize to a dict matching the wire format.""" + d: 
dict[str, Any] = { + "id": self.id, + "from": self.from_, + "to": self.to, + "seq": self.seq, + "ack": self.ack, + "payload": self.payload, + "streamId": self.stream_id, + "controlFlags": self.control_flags, + } + if self.service_name is not None: + d["serviceName"] = self.service_name + if self.procedure_name is not None: + d["procedureName"] = self.procedure_name + if self.tracing is not None: + d["tracing"] = self.tracing + return d + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> TransportMessage: + """Deserialize from a wire format dict.""" + return cls( + id=d["id"], + from_=d["from"], + to=d["to"], + seq=d["seq"], + ack=d["ack"], + payload=d["payload"], + stream_id=d["streamId"], + control_flags=d.get("controlFlags", 0), + service_name=d.get("serviceName"), + procedure_name=d.get("procedureName"), + tracing=d.get("tracing"), + ) + + +@dataclass +class PartialTransportMessage: + """A transport message missing id, from, to, seq, ack -- filled in by Session.""" + + payload: Any + stream_id: str + control_flags: int = 0 + service_name: str | None = None + procedure_name: str | None = None + tracing: dict[str, str] | None = None + + +# --- Result Types --- + +T = TypeVar("T") +E = TypeVar("E") + + +@dataclass +class OkResult(Generic[T]): + """Success result.""" + + payload: T + ok: bool = field(default=True, init=False) + + +@dataclass +class ErrResult(Generic[E]): + """Error result.""" + + payload: E + ok: bool = field(default=False, init=False) + + +Result = Union[OkResult[T], ErrResult[E]] + + +def Ok(payload: Any) -> OkResult: + """Create an Ok result.""" + return OkResult(payload=payload) + + +def Err(payload: Any) -> ErrResult: + """Create an Err result.""" + return ErrResult(payload=payload) + + +def ok_result(payload: Any) -> dict[str, Any]: + """Create an ok result dict for wire format.""" + return {"ok": True, "payload": payload} + + +def err_result(code: str, message: str, extras: Any = None) -> dict[str, Any]: + """Create an error result 
dict for wire format.""" + p: dict[str, Any] = {"code": code, "message": message} + if extras is not None: + p["extras"] = extras + return {"ok": False, "payload": p} + + +# --- Protocol Error Codes --- + +UNEXPECTED_DISCONNECT_CODE = "UNEXPECTED_DISCONNECT" +CANCEL_CODE = "CANCEL" +UNCAUGHT_ERROR_CODE = "UNCAUGHT_ERROR" +INVALID_REQUEST_CODE = "INVALID_REQUEST" + +# --- Protocol Version --- + +PROTOCOL_VERSION = "v2.0" + + +# --- Control Message Helpers --- + + +def handshake_request_payload( + session_id: str, + next_expected_seq: int, + next_sent_seq: int, + metadata: Any = None, +) -> dict[str, Any]: + """Create a handshake request payload.""" + payload: dict[str, Any] = { + "type": "HANDSHAKE_REQ", + "protocolVersion": PROTOCOL_VERSION, + "sessionId": session_id, + "expectedSessionState": { + "nextExpectedSeq": next_expected_seq, + "nextSentSeq": next_sent_seq, + }, + } + if metadata is not None: + payload["metadata"] = metadata + return payload + + +def handshake_response_ok(session_id: str) -> dict[str, Any]: + return { + "type": "HANDSHAKE_RESP", + "status": {"ok": True, "sessionId": session_id}, + } + + +def ack_payload() -> dict[str, str]: + """Heartbeat/ACK control payload.""" + return {"type": "ACK"} + + +def close_payload() -> dict[str, str]: + """Stream close control payload.""" + return {"type": "CLOSE"} + + +def close_stream_message(stream_id: str) -> PartialTransportMessage: + """Create a close stream partial message.""" + return PartialTransportMessage( + payload=close_payload(), + stream_id=stream_id, + control_flags=ControlFlags.StreamClosedBit, + ) + + +def cancel_message(stream_id: str, error_payload: dict) -> PartialTransportMessage: + """Create a cancel stream partial message.""" + return PartialTransportMessage( + payload=error_payload, + stream_id=stream_id, + control_flags=ControlFlags.StreamCancelBit, + ) + + +def heartbeat_message() -> PartialTransportMessage: + """Create a heartbeat partial message.""" + return PartialTransportMessage( 
        payload=ack_payload(),
        stream_id="heartbeat",
        control_flags=ControlFlags.AckBit,
    )


# --- Handshake Error Codes ---

# Codes for which the client should drop its session and retry with a
# fresh handshake (server lost or disagrees about session state).
RETRIABLE_HANDSHAKE_CODES = frozenset({"SESSION_STATE_MISMATCH"})
# Permanent rejections: retrying with the same parameters cannot succeed.
FATAL_HANDSHAKE_CODES = frozenset(
    {
        "MALFORMED_HANDSHAKE_META",
        "MALFORMED_HANDSHAKE",
        "PROTOCOL_VERSION_MISMATCH",
        "REJECTED_BY_CUSTOM_HANDLER",
        "REJECTED_UNSUPPORTED_CLIENT",
    }
)
"""Pytest fixtures for River Python client tests.

Manages the lifecycle of a TypeScript test server process that the
Python client connects to.
"""

from __future__ import annotations

import asyncio
import os
import re
import signal
import subprocess
import time
from typing import Generator

import pytest


# Paths are resolved relative to this file; the repo root is assumed to be
# two levels up (python-client/tests -> river repo root).
TESTS_DIR = os.path.dirname(__file__)
SERVER_TS = os.path.join(TESTS_DIR, "test_server.ts")
SERVER_MJS = os.path.join(TESTS_DIR, "test_server.mjs")
RIVER_ROOT = os.path.abspath(os.path.join(TESTS_DIR, "..", ".."))
ESBUILD = os.path.join(RIVER_ROOT, "node_modules", ".bin", "esbuild")


def _build_test_server() -> None:
    """Bundle test_server.ts -> test_server.mjs using esbuild.

    esbuild handles the river repo's bundler-style module resolution at
    build time, producing a single ESM file that plain ``node`` can run.

    Raises:
        RuntimeError: if esbuild exits non-zero (stderr included).
    """
    result = subprocess.run(
        [
            ESBUILD,
            SERVER_TS,
            "--bundle",
            "--platform=node",
            "--format=esm",
            f"--outfile={SERVER_MJS}",
            # keep heavy deps external so the bundle stays small and
            # we reuse whatever is already in node_modules
            "--external:ws",
            "--external:@sinclair/typebox",
        ],
        cwd=RIVER_ROOT,
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"esbuild failed ({result.returncode}):\n{result.stderr}"
        )


@pytest.fixture(scope="session")
def event_loop():
    """Create an event loop for the entire test session."""
    # NOTE(review): overriding the `event_loop` fixture is deprecated in
    # recent pytest-asyncio releases — confirm the pinned version still
    # supports this pattern with asyncio_mode = "auto".
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="session")
def river_server_port() -> Generator[int, None, None]:
    """Build and start the TypeScript test server, yield its port.

    The server is built once via esbuild and kept alive for the entire
    test session. The server is expected to print ``RIVER_PORT=<n>`` on
    stdout once it is listening.
    """
    _build_test_server()

    # NOTE(review): stderr is piped but only drained on failure — if the
    # server logs heavily to stderr the pipe buffer could fill and stall
    # it; confirm the test server is quiet on stderr.
    proc = subprocess.Popen(
        ["node", SERVER_MJS],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=RIVER_ROOT,
    )

    # Wait for the server to print the port
    # NOTE(review): readline() blocks until a full line arrives, so the
    # 30s deadline is only checked between lines — a live-but-silent
    # server would block here indefinitely. Confirm this is acceptable.
    port = None
    deadline = time.monotonic() + 30
    assert proc.stdout is not None
    while time.monotonic() < deadline:
        line = proc.stdout.readline().decode("utf-8").strip()
        if not line:
            # Empty read: either EOF (process died) or nothing yet.
            if proc.poll() is not None:
                stderr = (
                    proc.stderr.read().decode("utf-8") if proc.stderr else ""
                )
                raise RuntimeError(
                    f"Test server exited with code {proc.returncode}.\n"
                    f"stderr: {stderr}"
                )
            time.sleep(0.1)
            continue
        m = re.match(r"RIVER_PORT=(\d+)", line)
        if m:
            port = int(m.group(1))
            break

    if port is None:
        proc.kill()
        raise RuntimeError("Failed to get port from test server within 30s")

    yield port

    # Graceful shutdown first; escalate to SIGKILL after 5s.
    proc.send_signal(signal.SIGTERM)
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        proc.kill()


@pytest.fixture
def server_url(river_server_port: int) -> str:
    """Return the WebSocket URL for the test server."""
    return f"ws://127.0.0.1:{river_server_port}"
"""End-to-end tests for the River Python client.

Tests the Python client against the TypeScript test server, covering
all four procedure types and core protocol behavior.
"""

from __future__ import annotations

import asyncio
import pytest

from river.client import RiverClient
from river.transport import WebSocketClientTransport
from river.codec import NaiveJsonCodec


# -- helpers --


async def make_client(server_url: str, **kwargs) -> RiverClient:
    """Create a RiverClient wired to the test server.

    NOTE(review): no connection is opened here — presumably the transport
    connects lazily on first invoke (connect_on_invoke=True); confirm the
    docstring's "connected" wording matches that behavior.
    """
    transport = WebSocketClientTransport(
        ws_url=server_url,
        client_id=None,  # auto-generate
        server_id="SERVER",
        codec=NaiveJsonCodec(),
        connect_on_invoke=kwargs.get("connect_on_invoke", True),
        eagerly_connect=kwargs.get("eagerly_connect", False),
    )
    return RiverClient(transport, server_id="SERVER")


async def cleanup_client(client: RiverClient) -> None:
    """Tear down a client by closing its transport (and all sessions)."""
    await client.transport.close()


# =====================================================================
# RPC Tests
# =====================================================================


class TestRpc:
    """rpc procedures: single request, single response."""

    @pytest.mark.asyncio
    async def test_rpc_basic(self, server_url: str):
        """Basic RPC call returns correct result."""
        client = await make_client(server_url)
        try:
            result = await client.rpc("test", "add", {"n": 3})
            assert result["ok"] is True
            assert result["payload"]["result"] == 3
        finally:
            await cleanup_client(client)

    @pytest.mark.asyncio
    async def test_fallible_rpc_success(self, server_url: str):
        """Fallible RPC returns Ok on valid input."""
        client = await make_client(server_url)
        try:
            result = await client.rpc("fallible", "divide", {"a": 10, "b": 2})
            assert result["ok"] is True
            assert result["payload"]["result"] == 5.0
        finally:
            await cleanup_client(client)

    @pytest.mark.asyncio
    async def test_fallible_rpc_error(self, server_url: str):
        """Fallible RPC returns Err with correct error code."""
        client = await make_client(server_url)
        try:
            result = await client.rpc("fallible", "divide", {"a": 10, "b": 0})
            assert result["ok"] is False
            assert result["payload"]["code"] == "DIV_BY_ZERO"
        finally:
            await cleanup_client(client)

    @pytest.mark.asyncio
    async def test_concurrent_rpcs(self, server_url: str):
        """Multiple concurrent RPCs all complete correctly."""
        client = await make_client(server_url)
        try:
            tasks = [
                client.rpc("ordering", "add", {"n": i}) for i in range(10)
            ]
            results = await asyncio.gather(*tasks)
            # gather preserves order, so result i matches request i.
            for i, result in enumerate(results):
                assert result["ok"] is True
                assert result["payload"]["n"] == i
        finally:
            await cleanup_client(client)


# =====================================================================
# Stream Tests
# =====================================================================


class TestStream:
    """stream procedures: bidirectional request/response streams."""

    @pytest.mark.asyncio
    async def test_stream_basic(self, server_url: str):
        """Stream echoes messages correctly, skipping ignored ones."""
        client = await make_client(server_url)
        try:
            stream = client.stream("test", "echo", {})

            # Write messages
            stream.req_writable.write({"msg": "hello", "ignore": False})
            stream.req_writable.write({"msg": "world", "ignore": False})
            stream.req_writable.write({"msg": "skip", "ignore": True})
            stream.req_writable.write({"msg": "end", "ignore": False})
            stream.req_writable.close()

            # Read responses
            results = []
            async for msg in stream.res_readable:
                results.append(msg)

            # The "ignore": True message is dropped by the server.
            assert len(results) == 3
            assert results[0]["ok"] is True
            assert results[0]["payload"]["response"] == "hello"
            assert results[1]["payload"]["response"] == "world"
            assert results[2]["payload"]["response"] == "end"
        finally:
            await cleanup_client(client)

    @pytest.mark.asyncio
    async def test_stream_empty(self, server_url: str):
        """Stream with immediate close returns no results."""
        client = await make_client(server_url)
        try:
            stream = client.stream("test", "echo", {})
            stream.req_writable.close()

            results = await stream.res_readable.collect()
            assert len(results) == 0
        finally:
            await cleanup_client(client)

    @pytest.mark.asyncio
    async def test_stream_with_init_message(self, server_url: str):
        """Stream handler receives the init message."""
        client = await make_client(server_url)
        try:
            stream = client.stream(
                "test", "echoWithPrefix", {"prefix": "test"}
            )
            stream.req_writable.write({"msg": "hello", "ignore": False})
            stream.req_writable.write({"msg": "world", "ignore": False})
            stream.req_writable.close()

            results = await stream.res_readable.collect()
            assert len(results) == 2
            # Server prepends the init message's prefix to each echo.
            assert results[0]["payload"]["response"] == "test hello"
            assert results[1]["payload"]["response"] == "test world"
        finally:
            await cleanup_client(client)

    @pytest.mark.asyncio
    async def test_fallible_stream(self, server_url: str):
        """Stream correctly propagates both Ok and Err results."""
        client = await make_client(server_url)
        try:
            stream = client.stream("fallible", "echo", {})

            # Normal message
            stream.req_writable.write(
                {"msg": "hello", "throwResult": False, "throwError": False}
            )
            done, msg = await stream.res_readable.next()
            assert not done
            assert msg["ok"] is True
            assert msg["payload"]["response"] == "hello"

            # Error result (service-level error)
            stream.req_writable.write(
                {"msg": "fail", "throwResult": True, "throwError": False}
            )
            done, msg = await stream.res_readable.next()
            assert not done
            assert msg["ok"] is False
            assert msg["payload"]["code"] == "STREAM_ERROR"

            # Uncaught error (causes stream cancel)
            stream.req_writable.write(
                {"msg": "throw", "throwResult": False, "throwError": True}
            )
            done, msg = await stream.res_readable.next()
            assert not done
            assert msg["ok"] is False
            assert msg["payload"]["code"] == "UNCAUGHT_ERROR"
        finally:
            await cleanup_client(client)

    @pytest.mark.asyncio
    async def test_concurrent_streams(self, server_url: str):
        """Multiple concurrent streams work independently."""
        client = await make_client(server_url)
        try:
            streams = []
            for _ in range(5):
                s = client.stream("test", "echo", {})
                streams.append(s)

            # Write to each stream
            for i, s in enumerate(streams):
                s.req_writable.write({"msg": f"msg-{i}", "ignore": False})
                s.req_writable.close()

            # Read from each stream; each must only see its own message.
            for i, s in enumerate(streams):
                results = await s.res_readable.collect()
                assert len(results) == 1
                assert results[0]["payload"]["response"] == f"msg-{i}"
        finally:
            await cleanup_client(client)


# =====================================================================
# Subscription Tests
# =====================================================================


class TestSubscription:
    """subscription procedures: server pushes updates to the client."""

    @pytest.mark.asyncio
    async def test_subscription_basic(self, server_url: str):
        """Subscription receives initial value and updates."""
        client = await make_client(server_url)
        try:
            sub = client.subscribe("subscribable", "value", {})

            # Read initial value
            done, msg = await sub.res_readable.next()
            assert not done
            assert msg["ok"] is True
            initial_count = msg["payload"]["count"]

            # Trigger an update via a separate RPC on the same service.
            add_result = await client.rpc("subscribable", "add", {"n": 1})
            assert add_result["ok"] is True

            # Read updated value
            done, msg = await sub.res_readable.next()
            assert not done
            assert msg["ok"] is True
            assert msg["payload"]["count"] == initial_count + 1
        finally:
            await cleanup_client(client)


# =====================================================================
# Upload Tests
# =====================================================================


class TestUpload:
    """upload procedures: client streams requests, single final response."""

    @pytest.mark.asyncio
    async def test_upload_basic(self, server_url: str):
        """Upload sums multiple
values correctly.""" + client = await make_client(server_url) + try: + upload = client.upload("uploadable", "addMultiple", {}) + upload.req_writable.write({"n": 1}) + upload.req_writable.write({"n": 2}) + upload.req_writable.close() + + result = await upload.finalize() + assert result["ok"] is True + assert result["payload"]["result"] == 3 + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_upload_empty(self, server_url: str): + """Upload with no data returns zero.""" + client = await make_client(server_url) + try: + upload = client.upload("uploadable", "addMultiple", {}) + upload.req_writable.close() + + result = await upload.finalize() + assert result["ok"] is True + assert result["payload"]["result"] == 0 + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_upload_with_init_message(self, server_url: str): + """Upload handler receives the init message.""" + client = await make_client(server_url) + try: + upload = client.upload( + "uploadable", "addMultipleWithPrefix", {"prefix": "test"} + ) + upload.req_writable.write({"n": 1}) + upload.req_writable.write({"n": 2}) + upload.req_writable.close() + + result = await upload.finalize() + assert result["ok"] is True + assert result["payload"]["result"] == "test 3" + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_upload_server_cancel(self, server_url: str): + """Upload receives server-initiated cancel when limit exceeded.""" + client = await make_client(server_url) + try: + upload = client.upload("uploadable", "cancellableAdd", {}) + upload.req_writable.write({"n": 9}) + upload.req_writable.write({"n": 1}) + # Don't close - server should cancel + + result = await upload.finalize() + assert result["ok"] is False + assert result["payload"]["code"] == "CANCEL" + finally: + await cleanup_client(client) + + +# ===================================================================== +# Disconnect Tests +# 
class TestDisconnect:
    """Procedure calls made after the transport is closed must fail fast.

    Every procedure type is expected to surface the same terminal error:
    an Err result carrying the UNEXPECTED_DISCONNECT code.
    """

    @staticmethod
    def _assert_disconnect(result) -> None:
        # Shared shape check: Err result with the disconnect error code.
        assert result["ok"] is False
        assert result["payload"]["code"] == "UNEXPECTED_DISCONNECT"

    @pytest.mark.asyncio
    async def test_rpc_on_closed_transport(self, server_url: str):
        """RPC on a closed transport returns UNEXPECTED_DISCONNECT."""
        client = await make_client(server_url)
        await client.transport.close()

        self._assert_disconnect(await client.rpc("test", "add", {"n": 1}))

    @pytest.mark.asyncio
    async def test_stream_on_closed_transport(self, server_url: str):
        """Stream on a closed transport returns UNEXPECTED_DISCONNECT."""
        client = await make_client(server_url)
        await client.transport.close()

        stream = client.stream("test", "echo", {})
        done, msg = await stream.res_readable.next()
        assert not done
        self._assert_disconnect(msg)

    @pytest.mark.asyncio
    async def test_upload_on_closed_transport(self, server_url: str):
        """Upload on a closed transport returns UNEXPECTED_DISCONNECT."""
        client = await make_client(server_url)
        await client.transport.close()

        upload = client.upload("uploadable", "addMultiple", {})
        # The request side is already unwritable before finalize is called.
        assert not upload.req_writable.is_writable()
        self._assert_disconnect(await upload.finalize())

    @pytest.mark.asyncio
    async def test_subscription_on_closed_transport(self, server_url: str):
        """Subscription on a closed transport returns UNEXPECTED_DISCONNECT."""
        client = await make_client(server_url)
        await client.transport.close()

        sub = client.subscribe("subscribable", "value", {})
        done, msg = await sub.res_readable.next()
        assert not done
        self._assert_disconnect(msg)
===================================================================== + + +class TestClientCancellation: + """Tests for client-initiated cancellation via abort signal. + + Uses the cancel.blocking* handlers on the test server which never resolve, + allowing us to test that the client abort properly sends CANCEL and + receives the CANCEL result. + """ + + @pytest.mark.asyncio + async def test_cancel_rpc(self, server_url: str): + """Client abort on RPC returns CANCEL error.""" + client = await make_client(server_url) + try: + abort_evt = asyncio.Event() + + async def do_abort(): + await asyncio.sleep(0.2) + abort_evt.set() + + asyncio.ensure_future(do_abort()) + result = await client.rpc( + "cancel", "blockingRpc", {}, abort_signal=abort_evt + ) + assert result["ok"] is False + assert result["payload"]["code"] == "CANCEL" + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_cancel_stream(self, server_url: str): + """Client abort on stream returns CANCEL error.""" + client = await make_client(server_url) + try: + abort_evt = asyncio.Event() + stream = client.stream( + "cancel", "blockingStream", {}, abort_signal=abort_evt + ) + # Give server time to receive and process the init message + await asyncio.sleep(0.2) + abort_evt.set() + await asyncio.sleep(0) + + results = await stream.res_readable.collect() + assert len(results) == 1 + assert results[0]["ok"] is False + assert results[0]["payload"]["code"] == "CANCEL" + assert not stream.req_writable.is_writable() + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_cancel_upload(self, server_url: str): + """Client abort on upload returns CANCEL error.""" + client = await make_client(server_url) + try: + abort_evt = asyncio.Event() + upload = client.upload( + "cancel", "blockingUpload", {}, abort_signal=abort_evt + ) + # Give server time to receive + await asyncio.sleep(0.2) + abort_evt.set() + + result = await upload.finalize() + assert result["ok"] is False + 
assert result["payload"]["code"] == "CANCEL" + assert not upload.req_writable.is_writable() + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_cancel_subscription(self, server_url: str): + """Client abort on subscription returns CANCEL error.""" + client = await make_client(server_url) + try: + abort_evt = asyncio.Event() + sub = client.subscribe( + "cancel", "blockingSubscription", {}, abort_signal=abort_evt + ) + # Give server time to receive + await asyncio.sleep(0.2) + abort_evt.set() + await asyncio.sleep(0) + + done, msg = await sub.res_readable.next() + assert not done + assert msg["ok"] is False + assert msg["payload"]["code"] == "CANCEL" + finally: + await cleanup_client(client) + + +# ===================================================================== +# Idempotent Close / Post-Close Safety Tests +# ===================================================================== + + +class TestIdempotentClose: + """Tests that closing/aborting after completion is a safe no-op.""" + + @pytest.mark.asyncio + async def test_stream_idempotent_close(self, server_url: str): + """Closing and aborting a stream after it finished is safe.""" + client = await make_client(server_url) + try: + abort_evt = asyncio.Event() + stream = client.stream( + "test", "echo", {}, abort_signal=abort_evt + ) + stream.req_writable.write({"msg": "abc", "ignore": False}) + stream.req_writable.close() + + done, msg = await stream.res_readable.next() + assert not done + assert msg["ok"] is True + assert msg["payload"]["response"] == "abc" + + # Wait for server close to be received + await asyncio.sleep(0.1) + + # Abort after stream completed - should be a no-op + abort_evt.set() + await asyncio.sleep(0.05) + + # Drain any remaining messages - should be done or at most a cancel + done, val = await stream.res_readable.next() + # Either the stream is done, or we got a cancel (both ok) + if not done: + assert val["ok"] is False + + # "Accidentally" close again - no 
crash + stream.req_writable.close() + abort_evt.set() + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_subscription_idempotent_close(self, server_url: str): + """Aborting a subscription after it was already aborted is safe.""" + client = await make_client(server_url) + try: + abort_evt = asyncio.Event() + sub = client.subscribe( + "subscribable", "value", {}, abort_signal=abort_evt + ) + # Read initial value + done, msg = await sub.res_readable.next() + assert not done + assert msg["ok"] is True + + # Abort + abort_evt.set() + await asyncio.sleep(0.05) + + # Read the cancel + done, msg = await sub.res_readable.next() + assert not done + assert msg["ok"] is False + assert msg["payload"]["code"] == "CANCEL" + + # "Accidentally" abort again + abort_evt.set() + finally: + await cleanup_client(client) + + @pytest.mark.asyncio + async def test_cancellation_after_transport_close(self, server_url: str): + """Closing/aborting after transport close doesn't crash.""" + client = await make_client(server_url) + try: + abort_evt = asyncio.Event() + stream = client.stream( + "test", "echo", {}, abort_signal=abort_evt + ) + stream.req_writable.write({"msg": "1", "ignore": False}) + done, msg = await stream.res_readable.next() + assert not done + assert msg["payload"]["response"] == "1" + + # Close the transport + await client.transport.close() + await asyncio.sleep(0.05) + + # Closing writable after transport close should be safe + stream.req_writable.close() + # Aborting after transport close should be safe + abort_evt.set() + await asyncio.sleep(0.05) + # No crash = success + finally: + # Transport already closed + pass + + +# ===================================================================== +# Eagerly Connect Test +# ===================================================================== + + +class TestEagerConnect: + @pytest.mark.asyncio + async def test_eagerly_connect(self, server_url: str): + """eagerlyConnect creates a connection before 
any procedure call.""" + transport = WebSocketClientTransport( + ws_url=server_url, + server_id="SERVER", + codec=NaiveJsonCodec(), + eagerly_connect=True, + ) + client = RiverClient(transport, server_id="SERVER", eagerly_connect=True) + try: + # Wait for the connection to be established + await asyncio.sleep(0.5) + # Should have a session now + assert len(transport.sessions) > 0 + # Verify the connection works by making a call + result = await client.rpc("test", "add", {"n": 1}) + assert result["ok"] is True + finally: + await transport.close() + + +# ===================================================================== +# Codec Tests +# ===================================================================== + + +class TestCodec: + @pytest.mark.asyncio + async def test_json_codec_rpc(self, server_url: str): + """JSON codec works for basic RPC.""" + transport = WebSocketClientTransport( + ws_url=server_url, + server_id="SERVER", + codec=NaiveJsonCodec(), + ) + client = RiverClient(transport, server_id="SERVER") + try: + result = await client.rpc("test", "add", {"n": 5}) + assert result["ok"] is True + finally: + await transport.close() + + @pytest.mark.asyncio + async def test_binary_codec_roundtrip(self): + """Binary (msgpack) codec encodes and decodes transport messages.""" + from river.codec import BinaryCodec, CodecMessageAdapter + from river.types import TransportMessage + + adapter = CodecMessageAdapter(BinaryCodec()) + msg = TransportMessage( + id="test123", + from_="client", + to="server", + seq=1, + ack=0, + payload={"data": "hello"}, + stream_id="s1", + control_flags=0, + ) + ok, buf = adapter.to_buffer(msg) + assert ok is True + ok, decoded = adapter.from_buffer(buf) + assert ok is True + assert decoded.payload == {"data": "hello"} + + +# ===================================================================== +# Stream Unit Tests +# ===================================================================== + + +class TestReadable: + @pytest.mark.asyncio + async 
def test_readable_close(self): + """Closing a readable makes it done.""" + from river.streams import Readable + + r: Readable = Readable() + r._trigger_close() + assert r.is_closed() + + @pytest.mark.asyncio + async def test_readable_iterate(self): + """Can iterate over pushed values.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": 1}) + r._push_value({"ok": True, "payload": 2}) + r._trigger_close() + + results = await r.collect() + assert len(results) == 2 + assert results[0]["payload"] == 1 + assert results[1]["payload"] == 2 + + @pytest.mark.asyncio + async def test_readable_push_after_close_raises(self): + """Pushing to a closed readable raises.""" + from river.streams import Readable + + r: Readable = Readable() + r._trigger_close() + with pytest.raises(RuntimeError): + r._push_value({"ok": True, "payload": 1}) + + @pytest.mark.asyncio + async def test_readable_double_close_raises(self): + """Closing a readable twice raises.""" + from river.streams import Readable + + r: Readable = Readable() + r._trigger_close() + with pytest.raises(RuntimeError): + r._trigger_close() + + @pytest.mark.asyncio + async def test_readable_break(self): + """Breaking a readable yields broken error on next read.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": 1}) + # Grab iterator before break (since break locks the stream) + done, val = await r.next() + assert not done + assert val["payload"] == 1 + r.break_() + done, val = await r.next() + assert not done + assert val["ok"] is False + assert val["payload"]["code"] == "READABLE_BROKEN" + r._trigger_close() + + @pytest.mark.asyncio + async def test_readable_async_for(self): + """Works with async for loop.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": "a"}) + r._push_value({"ok": True, "payload": "b"}) + r._trigger_close() + + values = [] + async for 
class TestWritable:
    """Unit tests for the Writable stream wrapper (synchronous API)."""

    def test_writable_write(self):
        """Write callback is invoked."""
        from river.streams import Writable

        seen = []
        w: Writable = Writable(write_cb=seen.append)
        for value in (1, 2):
            w.write(value)
        assert seen == [1, 2]

    def test_writable_close(self):
        """Close callback is invoked once."""
        from river.streams import Writable

        # Record each close-callback invocation as a list entry.
        closes = []
        w: Writable = Writable(
            write_cb=lambda _value: None,
            close_cb=lambda: closes.append(None),
        )
        assert w.is_writable()
        w.close()
        assert not w.is_writable()
        assert len(closes) == 1

    def test_writable_idempotent_close(self):
        """Closing multiple times only invokes callback once."""
        from river.streams import Writable

        closes = []
        w: Writable = Writable(
            write_cb=lambda _value: None,
            close_cb=lambda: closes.append(None),
        )
        for _ in range(3):
            w.close()
        assert len(closes) == 1

    def test_writable_write_after_close_raises(self):
        """Writing after close raises."""
        from river.streams import Writable

        w: Writable = Writable(write_cb=lambda _value: None)
        w.close()
        with pytest.raises(RuntimeError):
            w.write(42)

    def test_writable_close_with_value(self):
        """Close with a final value writes it before closing."""
        from river.streams import Writable

        seen = []
        w: Writable = Writable(write_cb=seen.append)
        w.close(42)
        assert seen == [42]
        assert w.is_closed()
generate_id + + ids = {generate_id() for _ in range(1000)} + assert len(ids) == 1000 + + def test_control_flags(self): + """Control flag bit operations work correctly.""" + from river.types import ( + ControlFlags, + is_ack, + is_stream_open, + is_stream_cancel, + is_stream_close, + ) + + assert is_ack(ControlFlags.AckBit) + assert not is_ack(0) + assert is_stream_open(ControlFlags.StreamOpenBit) + assert is_stream_close(ControlFlags.StreamClosedBit) + assert is_stream_cancel(ControlFlags.StreamCancelBit) + + # Combined flags + combined = ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit + assert is_stream_open(combined) + assert is_stream_close(combined) + assert not is_ack(combined) + + def test_transport_message_roundtrip(self): + """TransportMessage serializes and deserializes correctly.""" + from river.types import TransportMessage + + msg = TransportMessage( + id="test123", + from_="client1", + to="server1", + seq=5, + ack=3, + payload={"data": "hello"}, + stream_id="stream1", + control_flags=0, + service_name="myService", + procedure_name="myProc", + ) + d = msg.to_dict() + assert d["from"] == "client1" + assert d["to"] == "server1" + assert d["serviceName"] == "myService" + + msg2 = TransportMessage.from_dict(d) + assert msg2.from_ == "client1" + assert msg2.seq == 5 + assert msg2.service_name == "myService" + + +# ===================================================================== +# Codec Unit Tests +# ===================================================================== + + +class TestReadableLocking: + """Tests for Readable stream locking semantics (mirrors TS streams.test.ts).""" + + @pytest.mark.asyncio + async def test_lock_on_aiter(self): + """__aiter__ locks the stream; second call raises TypeError.""" + from river.streams import Readable + + r: Readable = Readable() + r.__aiter__() + assert not r.is_readable() + with pytest.raises(TypeError): + r.__aiter__() + r._trigger_close() + + @pytest.mark.asyncio + async def 
test_lock_on_collect(self): + """collect() locks the stream; __aiter__ raises TypeError.""" + from river.streams import Readable + + r: Readable = Readable() + # Don't await - just start collect (it will block waiting for close) + collect_task = asyncio.ensure_future(r.collect()) + await asyncio.sleep(0) # yield to let collect start + assert not r.is_readable() + with pytest.raises(TypeError): + r.__aiter__() + r._trigger_close() + await collect_task + + @pytest.mark.asyncio + async def test_lock_on_break(self): + """break_() locks the stream; __aiter__ raises TypeError.""" + from river.streams import Readable + + r: Readable = Readable() + r.break_() + assert not r.is_readable() + with pytest.raises(TypeError): + r.__aiter__() + r._trigger_close() + + @pytest.mark.asyncio + async def test_raw_iter_from_aiter(self): + """Can use the raw iterator from __aiter__.""" + from river.streams import Readable + + r: Readable = Readable() + it = r.__aiter__() + next_p = it.__anext__() + r._push_value({"ok": True, "payload": 1}) + val = await next_p + assert val == {"ok": True, "payload": 1} + next_p2 = it.__anext__() + r._trigger_close() + with pytest.raises(StopAsyncIteration): + await next_p2 + + +class TestReadableIteration: + """Tests for Readable iteration edge cases (mirrors TS streams.test.ts).""" + + @pytest.mark.asyncio + async def test_values_pushed_before_close(self): + """Can iterate values that were pushed before close.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": 1}) + r._push_value({"ok": True, "payload": 2}) + r._push_value({"ok": True, "payload": 3}) + r._trigger_close() + done, val = await r.next() + assert not done and val["payload"] == 1 + done, val = await r.next() + assert not done and val["payload"] == 2 + done, val = await r.next() + assert not done and val["payload"] == 3 + done, val = await r.next() + assert done + + @pytest.mark.asyncio + async def test_eager_iteration(self): + """Read 
before push resolves in order.""" + from river.streams import Readable + + r: Readable = Readable() + # Start reading before values are pushed + t1 = asyncio.ensure_future(r.next()) + t2 = asyncio.ensure_future(r.next()) + # Give tasks a chance to start waiting + await asyncio.sleep(0) + r._push_value({"ok": True, "payload": 1}) + r._push_value({"ok": True, "payload": 2}) + done1, val1 = await t1 + done2, val2 = await t2 + assert not done1 and val1["payload"] == 1 + assert not done2 and val2["payload"] == 2 + # Third read + close + t3 = asyncio.ensure_future(r.next()) + await asyncio.sleep(0) + r._push_value({"ok": True, "payload": 3}) + r._trigger_close() + done3, val3 = await t3 + assert not done3 and val3["payload"] == 3 + done4, _ = await r.next() + assert done4 + + @pytest.mark.asyncio + async def test_not_resolve_until_push(self): + """Pending next() doesn't resolve until push or close.""" + from river.streams import Readable + + r: Readable = Readable() + next_p = asyncio.ensure_future(r.next()) + # Should not resolve yet + result = await asyncio.wait_for( + asyncio.shield(next_p), timeout=0.01 + ) if False else None + done = next_p.done() + assert not done, "next() should not resolve before push" + + r._push_value({"ok": True, "payload": 1}) + await asyncio.sleep(0) + done_v, val = await next_p + assert not done_v and val["payload"] == 1 + + # isDone should not resolve until close + done_p = asyncio.ensure_future(r.next()) + await asyncio.sleep(0.01) + assert not done_p.done(), "next() should not resolve before close" + r._trigger_close() + done_v2, _ = await done_p + assert done_v2 + + @pytest.mark.asyncio + async def test_collect_after_close(self): + """collect() returns all values when called after close.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": 1}) + r._push_value({"ok": True, "payload": 2}) + r._push_value({"ok": True, "payload": 3}) + r._trigger_close() + results = await r.collect() + 
assert len(results) == 3 + assert [v["payload"] for v in results] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_collect_waits_for_close(self): + """collect() doesn't resolve until the stream is closed.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": 1}) + collect_task = asyncio.ensure_future(r.collect()) + r._push_value({"ok": True, "payload": 2}) + r._push_value({"ok": True, "payload": 3}) + await asyncio.sleep(0.01) + assert not collect_task.done(), "collect should not resolve before close" + r._push_value({"ok": True, "payload": 4}) + r._trigger_close() + results = await collect_task + assert len(results) == 4 + assert [v["payload"] for v in results] == [1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_async_for_with_break(self): + """Breaking out of async for mid-stream stops iteration.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": 1}) + r._push_value({"ok": True, "payload": 2}) + assert r._has_values_in_queue() + values = [] + async for item in r: + values.append(item) + assert r._has_values_in_queue() + break + # After break, remaining values should be discarded (broken) + assert not r._has_values_in_queue() + + @pytest.mark.asyncio + async def test_error_results_in_iteration(self): + """Error results are yielded as part of iteration.""" + from river.streams import Readable + + r: Readable = Readable() + r._push_value({"ok": True, "payload": 1}) + r._push_value({"ok": True, "payload": 2}) + r._push_value( + {"ok": False, "payload": {"code": "SOME_ERROR", "message": "err"}} + ) + r._trigger_close() + results = [] + async for item in r: + results.append(item) + assert len(results) == 3 + assert results[0]["ok"] is True + assert results[1]["ok"] is True + assert results[2]["ok"] is False + assert results[2]["payload"]["code"] == "SOME_ERROR" + + +class TestReadableBreakVariants: + """Tests for Readable break() edge cases 
(mirrors TS streams.test.ts)."""

    @pytest.mark.asyncio
    async def test_break_signals_next(self):
        """break() signals the next read call.

        A broken Readable yields a READABLE_BROKEN error result instead of
        terminating iteration, so ``done`` stays False.
        """
        from river.streams import Readable

        r: Readable = Readable()
        r.break_()
        done, val = await r.next()
        assert not done
        assert val["ok"] is False
        assert val["payload"]["code"] == "READABLE_BROKEN"
        r._trigger_close()

    @pytest.mark.asyncio
    async def test_break_signals_pending(self):
        """break() signals a pending read."""
        from river.streams import Readable

        r: Readable = Readable()
        # Start a read before anything is available so it parks waiting.
        pending = asyncio.ensure_future(r.next())
        await asyncio.sleep(0)  # yield once so the read is actually pending
        r.break_()
        done, val = await pending
        assert not done
        assert val["ok"] is False
        assert val["payload"]["code"] == "READABLE_BROKEN"
        r._trigger_close()

    @pytest.mark.asyncio
    async def test_break_with_queued_value(self):
        """break() clears queue and yields broken error.

        Values buffered before the break are discarded, not delivered.
        """
        from river.streams import Readable

        r: Readable = Readable()
        r._push_value({"ok": True, "payload": 1})
        assert r._has_values_in_queue()
        r.break_()
        # The queued value must have been dropped by break_().
        assert not r._has_values_in_queue()
        done, val = await r.next()
        assert not done
        assert val["payload"]["code"] == "READABLE_BROKEN"
        r._trigger_close()

    @pytest.mark.asyncio
    async def test_break_with_queued_value_after_close(self):
        """break() after close with queued values still yields broken error."""
        from river.streams import Readable

        r: Readable = Readable()
        r._push_value({"ok": True, "payload": 1})
        r._trigger_close()
        # Even though the stream is closed, break takes precedence over
        # the value still sitting in the queue.
        r.break_()
        done, val = await r.next()
        assert not done
        assert val["payload"]["code"] == "READABLE_BROKEN"

    @pytest.mark.asyncio
    async def test_break_empty_queue_after_close(self):
        """break() after close with empty queue -> done."""
        from river.streams import Readable

        r: Readable = Readable()
        r._trigger_close()
        r.break_()
        # Nothing buffered and already closed: iteration simply ends.
        done, _ = await r.next()
        assert done

    @pytest.mark.asyncio
    async def test_break_ends_iteration_midstream(self):
        """break() during async for ends iteration.

        Only the first queued value and the READABLE_BROKEN error are
        observed; the remaining queued values are discarded.
        """
        from river.streams import Readable

        r: Readable = Readable()
        r._push_value({"ok": True, "payload": 1})
        r._push_value({"ok": True, "payload": 2})
        r._push_value({"ok": True, "payload": 3})

        # NOTE(review): `results` is collected but never asserted —
        # consider checking its contents as well.
        results = []
        i = 0
        async for item in r:
            if i == 0:
                assert item["payload"] == 1
                r.break_()  # break mid-iteration; remaining values dropped
            elif i == 1:
                assert item["ok"] is False
                assert item["payload"]["code"] == "READABLE_BROKEN"
            results.append(item)
            i += 1
        # Exactly two iterations: the first value + the broken-error result.
        assert i == 2


class TestCodecUnit:
    """Unit tests for the JSON/msgpack codecs and the message adapter."""

    def test_json_codec_encode_decode(self):
        """JSON codec round-trips correctly."""
        from river.codec import NaiveJsonCodec

        codec = NaiveJsonCodec()
        obj = {"key": "value", "num": 42, "nested": {"a": [1, 2, 3]}}
        buf = codec.to_buffer(obj)
        assert isinstance(buf, bytes)
        result = codec.from_buffer(buf)
        assert result == obj

    def test_json_codec_bytes_handling(self):
        """JSON codec handles bytes via base64."""
        from river.codec import NaiveJsonCodec

        codec = NaiveJsonCodec()
        # Raw bytes (including NUL and 0xFF) must survive the text codec.
        obj = {"data": b"\x00\x01\x02\xff"}
        buf = codec.to_buffer(obj)
        result = codec.from_buffer(buf)
        assert result["data"] == b"\x00\x01\x02\xff"

    def test_binary_codec_encode_decode(self):
        """Binary (msgpack) codec round-trips correctly."""
        from river.codec import BinaryCodec

        codec = BinaryCodec()
        obj = {"key": "value", "num": 42, "nested": {"a": [1, 2, 3]}}
        buf = codec.to_buffer(obj)
        assert isinstance(buf, bytes)
        result = codec.from_buffer(buf)
        assert result == obj

    def test_codec_adapter_valid(self):
        """CodecMessageAdapter encodes and decodes transport messages."""
        from river.codec import CodecMessageAdapter, NaiveJsonCodec
        from river.types import TransportMessage

        adapter = CodecMessageAdapter(NaiveJsonCodec())
        msg = TransportMessage(
            id="abc",
            from_="c1",
            to="s1",
            seq=0,
            ack=0,
            payload={"type": "ACK"},
            stream_id="heartbeat",
            control_flags=1,
        )
        # to_buffer/from_buffer both return an (ok, value) pair.
        ok, buf = adapter.to_buffer(msg)
        assert ok is True

        ok, result =
adapter.from_buffer(buf)
        assert ok is True
        # Round-trip preserves the message fields.
        assert result.id == "abc"
        assert result.from_ == "c1"

    def test_codec_adapter_invalid_buffer(self):
        """CodecMessageAdapter returns error on invalid bytes.

        On failure the adapter returns (False, error_message) instead of
        raising, so callers can handle decode errors explicitly.
        """
        from river.codec import CodecMessageAdapter, NaiveJsonCodec

        adapter = CodecMessageAdapter(NaiveJsonCodec())
        ok, result = adapter.from_buffer(b"not valid json")
        assert ok is False
        assert isinstance(result, str)
diff --git a/python-client/tests/test_server.ts b/python-client/tests/test_server.ts
new file mode 100644
index 00000000..4f5ab982
--- /dev/null
+++ b/python-client/tests/test_server.ts
@@ -0,0 +1,424 @@
/**
 * Standalone test server for the Python River client test suite.
 *
 * Starts a WebSocket server with the standard test services and prints
 * the port to stdout so the Python test harness can connect.
 *
 * Usage (from river repo root):
 *   npx tsx --tsconfig python-client/tsconfig.tsx.json python-client/tests/test_server.ts
 */
import http from 'node:http';
import { WebSocketServer } from 'ws';
import { WebSocketServerTransport } from '../../transport/impls/ws/server';
import {
  createServer,
  createServiceSchema,
  Procedure,
  Ok,
  Err,
} from '../../router';
import { Type } from '@sinclair/typebox';

const ServiceSchema = createServiceSchema();

// -------------------------------------------------------------------
// TestService – mirrors the TS TestServiceSchema
// -------------------------------------------------------------------
// Module-level accumulator shared by every `add` call for the lifetime
// of the server process (tests rely on this cumulative behavior).
let count = 0;

const TestServiceSchema = ServiceSchema.define({
  // rpc: add `n` to the running total and return the new total.
  add: Procedure.rpc({
    requestInit: Type.Object({ n: Type.Number() }),
    responseData: Type.Object({ result: Type.Number() }),
    responseError: Type.Never(),
    async handler({ reqInit }) {
      count += reqInit.n;
      return Ok({ result: count });
    },
  }),
  // stream: echo each request message back; messages with `ignore` set
  // produce no response.
  echo: Procedure.stream({
    requestInit: Type.Object({}),
    requestData: Type.Object({
      msg: Type.String(),
      ignore: Type.Optional(Type.Boolean()),
    }),
    responseData: Type.Object({ response: Type.String() }),
    responseError: Type.Never(),
    async handler({ reqReadable, resWritable }) {
      for await (const result of reqReadable) {
        // A non-ok result means the request stream ended with an error;
        // stop echoing and close the response side below.
        if (!result.ok) break;
        const val = result.payload;
        if (val.ignore) continue;
        resWritable.write(Ok({ response: val.msg }));
      }
      resWritable.close();
    },
  }),
  // Same as `echo` but prepends the prefix supplied in the init payload.
  echoWithPrefix: Procedure.stream({
    requestInit: Type.Object({ prefix: Type.String() }),
    requestData: Type.Object({
      msg: Type.String(),
      ignore: Type.Optional(Type.Boolean()),
    }),
    responseData: Type.Object({ response: Type.String() }),
    responseError: Type.Never(),
    async handler({ reqInit, reqReadable, resWritable }) {
      for await (const result of reqReadable) {
        if (!result.ok) break;
        const val = result.payload;
        if (val.ignore) continue;
        resWritable.write(Ok({ response: `${reqInit.prefix} ${val.msg}` }));
      }
      resWritable.close();
    },
  }),
});

// -------------------------------------------------------------------
// OrderingService – for message ordering tests
// -------------------------------------------------------------------
// Records every `add` argument in arrival order so a test can verify
// server-side ordering via `getAll`.
const msgs: number[] = [];

const OrderingServiceSchema = ServiceSchema.define({
  add: Procedure.rpc({
    requestInit: Type.Object({ n: Type.Number() }),
    responseData: Type.Object({ n: Type.Number() }),
    responseError: Type.Never(),
    async handler({ reqInit }) {
      msgs.push(reqInit.n);
      return Ok({ n: reqInit.n });
    },
  }),
  getAll: Procedure.rpc({
    requestInit: Type.Object({}),
    responseData: Type.Object({ msgs: Type.Array(Type.Number()) }),
    responseError: Type.Never(),
    async handler() {
      // Return a copy so later pushes can't mutate the response.
      return Ok({ msgs: [...msgs] });
    },
  }),
});

// -------------------------------------------------------------------
// FallibleService – service-level errors
// -------------------------------------------------------------------
const FallibleServiceSchema = ServiceSchema.define({
  // rpc: divides a/b, returning typed Err results for the two failure
  // modes instead of throwing.
  divide: Procedure.rpc({
    requestInit: Type.Object({ a: Type.Number(), b: Type.Number() }),
    responseData: Type.Object({ result: Type.Number() }),
    responseError: Type.Union([
      Type.Object({
        code: Type.Literal('DIV_BY_ZERO'),
        message: Type.String(),
      }),
      Type.Object({
        code: Type.Literal('INFINITY'),
        message: Type.String(),
      }),
    ]),
    async handler({ reqInit }) {
      if (reqInit.b === 0) {
        return Err({
          code: 'DIV_BY_ZERO' as const,
          message: 'Cannot divide by zero',
        });
      }
      const result = reqInit.a / reqInit.b;
      if (!isFinite(result)) {
        return Err({
          code: 'INFINITY' as const,
          message: 'Result is infinity',
        });
      }
      return Ok({ result });
    },
  }),
  // stream: echoes messages, but can be told per-message to either write
  // a typed Err result (`throwResult`) or throw an uncaught exception
  // (`throwError`) so the client's error paths can be exercised.
  echo: Procedure.stream({
    requestInit: Type.Object({}),
    requestData: Type.Object({
      msg: Type.String(),
      throwResult: Type.Optional(Type.Boolean()),
      throwError: Type.Optional(Type.Boolean()),
    }),
    responseData: Type.Object({ response: Type.String() }),
    responseError: Type.Object({
      code: Type.Literal('STREAM_ERROR'),
      message: Type.String(),
    }),
    async handler({ reqReadable, resWritable }) {
      for await (const result of reqReadable) {
        if (!result.ok) break;
        const val = result.payload;
        if (val.throwError) {
          // Deliberately uncaught: lets tests observe framework-level
          // handler-crash behavior.
          throw new Error('uncaught error');
        }
        if (val.throwResult) {
          resWritable.write(
            Err({ code: 'STREAM_ERROR' as const, message: 'stream error' }),
          );
          continue;
        }
        resWritable.write(Ok({ response: val.msg }));
      }
      resWritable.close();
    },
  }),
});

// -------------------------------------------------------------------
// SubscribableService – subscriptions
// -------------------------------------------------------------------
let subCount = 0;
type SubListener = (val: number) => void;
// NOTE(review): the generic parameter looks lost in extraction — this is
// presumably `new Set<SubListener>()`; confirm against the original file.
const subListeners = new Set();

const SubscribableServiceSchema = ServiceSchema.define({
  // rpc: bump the shared counter and notify every active subscriber.
  add: Procedure.rpc({
    requestInit: Type.Object({ n: Type.Number() }),
    responseData: Type.Object({ result: Type.Number() }),
    responseError: Type.Never(),
    async handler({ reqInit }) {
      subCount += reqInit.n;
      for (const l of subListeners) l(subCount);
      return Ok({ result: subCount });
    },
  }),
  // subscription: emits the current counter immediately, then every
  // subsequent value until the client aborts.
  value: Procedure.subscription({
    requestInit: Type.Object({}),
    responseData: Type.Object({ count: Type.Number() }),
    responseError: Type.Never(),
    async handler({ resWritable, ctx }) {
      const listener: SubListener = (val) => {
        resWritable.write(Ok({ count: val }));
      };
      // Send initial value
      resWritable.write(Ok({ count: subCount }));
      subListeners.add(listener);
      // Unregister and close when the subscription is cancelled, so the
      // listener set does not leak across tests.
      ctx.signal.addEventListener('abort', () => {
        subListeners.delete(listener);
        resWritable.close();
      });
    },
  }),
});

// -------------------------------------------------------------------
// UploadableService – uploads
// -------------------------------------------------------------------
const UploadableServiceSchema = ServiceSchema.define({
  // upload: sum all uploaded `n` values and return the total once the
  // client finalizes the upload.
  addMultiple: Procedure.upload({
    requestInit: Type.Object({}),
    requestData: Type.Object({ n: Type.Number() }),
    responseData: Type.Object({ result: Type.Number() }),
    responseError: Type.Never(),
    async handler({ reqReadable }) {
      let total = 0;
      for await (const result of reqReadable) {
        if (!result.ok) break;
        total += result.payload.n;
      }
      return Ok({ result: total });
    },
  }),
  // Same as addMultiple but the response is prefixed with init data.
  addMultipleWithPrefix: Procedure.upload({
    requestInit: Type.Object({ prefix: Type.String() }),
    requestData: Type.Object({ n: Type.Number() }),
    responseData: Type.Object({ result: Type.String() }),
    responseError: Type.Never(),
    async handler({ reqInit, reqReadable }) {
      let total = 0;
      for await (const result of reqReadable) {
        if (!result.ok) break;
        total += result.payload.n;
      }
      return Ok({ result: `${reqInit.prefix} ${total}` });
    },
  }),
  // upload: cancels the stream server-side and returns a typed error as
  // soon as the running total reaches 10.
  cancellableAdd: Procedure.upload({
    requestInit: Type.Object({}),
    requestData: Type.Object({ n: Type.Number() }),
    responseData: Type.Object({ result: Type.Number() }),
    responseError: Type.Object({
      code: Type.Literal('CANCEL'),
      message: Type.String(),
    }),
    async handler({ reqReadable, ctx }) {
      let total = 0;
      for await (const result of reqReadable) {
        if (!result.ok) break;
        total += result.payload.n;
        if (total >= 10) {
          ctx.cancel();
          return Err({
            code: 'CANCEL' as const,
            message: 'total exceeds limit',
          });
        }
      }
      return Ok({ result: total });
    },
  }),
});

// -------------------------------------------------------------------
// CancellationService – handlers that block forever for cancel tests
// -------------------------------------------------------------------
const CancellationServiceSchema = ServiceSchema.define({
  blockingRpc: Procedure.rpc({
    requestInit: Type.Object({}),
    responseData: Type.Object({}),
    responseError: Type.Never(),
    async handler({ ctx }) {
      // Block until cancelled
      // NOTE(review): generic parameter may have been lost in extraction
      // (likely `new Promise<...>`); confirm against the original file.
      return new Promise((resolve) => {
        ctx.signal.addEventListener('abort', () => {
          // Handler will be cancelled by the framework, nothing to resolve
        });
      });
    },
  }),
  blockingStream: Procedure.stream({
    requestInit: Type.Object({}),
    requestData: Type.Object({}),
    responseData: Type.Object({}),
    responseError: Type.Never(),
    async handler({ ctx }) {
      return new Promise(() => {
        // never resolves
      });
    },
  }),
  blockingUpload: Procedure.upload({
    requestInit: Type.Object({}),
    requestData: Type.Object({}),
    responseData: Type.Object({}),
    responseError: Type.Never(),
    async handler({ ctx }) {
      return new Promise(() => {
        // never resolves
      });
    },
  }),
  blockingSubscription: Procedure.subscription({
    requestInit: Type.Object({}),
    responseData: Type.Object({}),
    responseError: Type.Never(),
    async handler({ ctx }) {
      return new Promise(() => {
        // never resolves
      });
    },
  }),
  // RPC that resolves normally (for clean handler cancellation)
  immediateRpc: Procedure.rpc({
    requestInit: Type.Object({}),
    responseData: Type.Object({ done: Type.Boolean() }),
    responseError: Type.Never(),
    async handler() {
      return Ok({ done: true });
    },
  }),
  // Stream that writes one response and closes (for clean handler cancel)
  immediateStream: Procedure.stream({
    requestInit: Type.Object({}),
    requestData: Type.Object({}),
    responseData: Type.Object({ done: Type.Boolean() }),
    responseError: Type.Never(),
    async handler({ reqReadable, resWritable }) {
      resWritable.write(Ok({ done: true }));
      // Drain the request stream until the client closes it.
      for await (const result of reqReadable) {
        if (!result.ok) break;
      }
      resWritable.close();
    },
  }),
  // Upload that resolves immediately
  immediateUpload: Procedure.upload({
    requestInit: Type.Object({}),
    requestData: Type.Object({}),
    responseData: Type.Object({ done: Type.Boolean() }),
    responseError: Type.Never(),
    async handler({ reqReadable }) {
      for await (const result of reqReadable) {
        if (!result.ok) break;
      }
      return Ok({ done: true });
    },
  }),
  // Subscription that closes immediately
  immediateSubscription: Procedure.subscription({
    requestInit: Type.Object({}),
    responseData: Type.Object({ done: Type.Boolean() }),
    responseError: Type.Never(),
    async handler({ resWritable }) {
      resWritable.write(Ok({ done: true }));
      resWritable.close();
    },
  }),
  // Stream that sends N responses then closes (for idempotent close tests)
  countedStream: Procedure.stream({
    requestInit: Type.Object({ total: Type.Number() }),
    requestData: Type.Object({}),
    responseData: Type.Object({ i: Type.Number() }),
    responseError: Type.Never(),
    async handler({ reqInit, reqReadable, resWritable }) {
      for (let i = 0; i < reqInit.total; i++) {
        resWritable.write(Ok({ i }));
      }
      // Wait for client to close the request stream
      for await (const result of reqReadable) {
        if (!result.ok) break;
      }
      resWritable.close();
    },
  }),
});

// -------------------------------------------------------------------
// Boot the server
// -------------------------------------------------------------------
// Service registry: keys are the service names the Python client uses.
const services = {
  test: TestServiceSchema,
  ordering: OrderingServiceSchema,
  fallible: FallibleServiceSchema,
  subscribable: SubscribableServiceSchema,
  uploadable: UploadableServiceSchema,
  cancel: CancellationServiceSchema,
};

async function main() {
  const httpServer = http.createServer();
  // Listen on an OS-assigned port (0) and recover the actual port from
  // the bound address.
  // NOTE(review): generic parameter may have been lost in extraction
  // (likely `new Promise<number>`); confirm against the original file.
  const port = await new Promise((resolve, reject) => {
    httpServer.listen(0, '127.0.0.1', () => {
      const addr = httpServer.address();
      // address() returns a string for pipes/sockets; only the object
      // form carries a port.
      if (typeof addr === 'object' && addr) resolve(addr.port);
      else reject(new Error("couldn't get port"));
    });
  });

  const wss = new WebSocketServer({ server: httpServer });
  const serverTransport = new WebSocketServerTransport(wss, 'SERVER');
  const _server = createServer(serverTransport, services);

  // Signal that the server is ready by printing the port
  process.stdout.write(`RIVER_PORT=${port}\n`);

  // Keep the server alive
  process.on('SIGTERM', () => {
    _server.close().then(() => {
      httpServer.close();
      process.exit(0);
    });
  });
  process.on('SIGINT', () => {
    _server.close().then(() => {
      httpServer.close();
      process.exit(0);
    });
  });
}

main().catch((err) => {
  console.error('Failed to start test server:', err);
  process.exit(1);
});