-
Notifications
You must be signed in to change notification settings - Fork 135
fix: wait for WebSocket terminal status to prevent event loss #1832
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ec4b0fb
25aa320
0b407d8
ffff022
13ad596
bd8e311
83261dd
5287639
939f74b
69610ab
35ebdd1
802a991
622b711
d03112e
3fa5173
f054895
cc6c1be
6e354e6
cd3346e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| import time | ||
| import uuid | ||
| from collections.abc import Mapping | ||
| from queue import Empty, Queue | ||
| from typing import SupportsIndex, overload | ||
| from urllib.parse import urlparse | ||
|
|
||
|
|
@@ -555,6 +556,7 @@ class RemoteConversation(BaseConversation): | |
| _client: httpx.Client | ||
| _hook_processor: HookEventProcessor | None | ||
| _cleanup_initiated: bool | ||
| _terminal_status_queue: Queue[str] # Thread-safe queue for terminal status from WS | ||
| delete_on_close: bool = False | ||
|
|
||
| def __init__( | ||
|
|
@@ -609,6 +611,7 @@ def __init__( | |
| self._client = workspace.client | ||
| self._hook_processor = None | ||
| self._cleanup_initiated = False | ||
| self._terminal_status_queue: Queue[str] = Queue() | ||
|
|
||
| should_create = conversation_id is None | ||
| if conversation_id is not None: | ||
|
|
@@ -708,8 +711,21 @@ def __init__( | |
| # No visualization (visualizer is None) | ||
| self._visualizer = None | ||
|
|
||
| # Add a callback that signals when run completes via WebSocket | ||
| # This ensures we wait for all events to be delivered before run() returns | ||
| def run_complete_callback(event: Event) -> None: | ||
| if isinstance(event, ConversationStateUpdateEvent): | ||
| if event.key == "execution_status": | ||
| try: | ||
| status = ConversationExecutionStatus(event.value) | ||
| if status.is_terminal(): | ||
| self._terminal_status_queue.put(event.value) | ||
| except ValueError: | ||
| pass # Unknown status value, ignore | ||
|
|
||
| # Compose all callbacks into a single callback | ||
| composed_callback = BaseConversation.compose_callbacks(self._callbacks) | ||
| all_callbacks = self._callbacks + [run_complete_callback] | ||
| composed_callback = BaseConversation.compose_callbacks(all_callbacks) | ||
xingyaoww marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Initialize WebSocket client for callbacks | ||
| self._ws_client = WebSocketCallbackClient( | ||
|
|
@@ -862,6 +878,14 @@ def run( | |
| Raises: | ||
| ConversationRunError: If the run fails or times out. | ||
| """ | ||
| # Drain any stale terminal status events from previous runs. | ||
| # This prevents stale events from causing early returns. | ||
| while True: | ||
| try: | ||
| self._terminal_status_queue.get_nowait() | ||
| except Empty: | ||
| break | ||
|
|
||
| # Trigger a run on the server using the dedicated run endpoint. | ||
| # Let the server tell us if it's already running (409), avoiding an extra GET. | ||
| try: | ||
|
|
@@ -889,10 +913,20 @@ def _wait_for_run_completion( | |
| poll_interval: float = 1.0, | ||
| timeout: float = 1800.0, | ||
| ) -> None: | ||
| """Poll the server until the conversation is no longer running. | ||
| """Wait for the conversation run to complete. | ||
|
|
||
| This method waits for the run to complete by listening for the terminal | ||
| status event via WebSocket. This ensures all events are delivered before | ||
| returning, avoiding the race condition where polling sees "finished" | ||
| status before WebSocket delivers the final events. | ||
|
|
||
| As a fallback, it also polls the server periodically. If the WebSocket | ||
| is delayed or disconnected, we return after multiple consecutive polls | ||
| show a terminal status, and reconcile events to catch any that were | ||
| missed via WebSocket. | ||
|
|
||
| Args: | ||
| poll_interval: Time in seconds between status polls. | ||
| poll_interval: Time in seconds between status polls (fallback). | ||
| timeout: Maximum time in seconds to wait. | ||
|
|
||
| Raises: | ||
|
|
@@ -901,6 +935,14 @@ def _wait_for_run_completion( | |
| responses are retried until timeout. | ||
| """ | ||
| start_time = time.monotonic() | ||
| consecutive_terminal_polls = 0 | ||
| # Return after this many consecutive terminal polls (fallback for WS issues). | ||
| # We use 3 polls to balance latency vs reliability: | ||
| # - 1 poll could be a transient state during shutdown | ||
| # - 2 polls might still catch a race condition | ||
| # - 3 polls (with default 1s interval = 3s total) provides high confidence | ||
| # that the run is truly complete while keeping fallback latency reasonable | ||
| TERMINAL_POLL_THRESHOLD = 3 | ||
|
|
||
| while True: | ||
| elapsed = time.monotonic() - start_time | ||
|
|
@@ -913,20 +955,57 @@ def _wait_for_run_completion( | |
| ), | ||
| ) | ||
|
|
||
| # Wait for either: | ||
| # 1. WebSocket delivers terminal status event (preferred) | ||
| # 2. Poll interval expires (fallback - check status via REST) | ||
| try: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Suggestion: Minor timeout edge case - if Example:
Suggested fix: remaining = timeout - elapsed
wait_time = min(poll_interval, remaining)
try:
ws_status = self._terminal_status_queue.get(timeout=wait_time)This ensures we respect the timeout more precisely, though the current behavior is probably acceptable for most use cases.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Suggestion: Minor timeout edge case - if Example:
Suggested fix: remaining = timeout - elapsed
wait_time = min(poll_interval, remaining)
try:
ws_status = self._terminal_status_queue.get(timeout=wait_time)This ensures we respect the timeout more precisely, though the current behavior is probably acceptable for most use cases. |
||
| ws_status = self._terminal_status_queue.get(timeout=poll_interval) | ||
| # Handle ERROR/STUCK states - raises ConversationRunError | ||
| self._handle_conversation_status(ws_status) | ||
|
|
||
| logger.info( | ||
| "Run completed via WebSocket notification " | ||
| "(status: %s, elapsed: %.1fs)", | ||
| ws_status, | ||
| elapsed, | ||
| ) | ||
| return | ||
| except Empty: | ||
| pass # Queue.get() timed out, fall through to REST polling | ||
|
|
||
| # Poll the server for status as a health check and fallback. | ||
| # This catches ERROR/STUCK states that need immediate attention, | ||
| # and provides a fallback if WebSocket is delayed/disconnected. | ||
| try: | ||
| status = self._poll_status_once() | ||
| except Exception as exc: | ||
| self._handle_poll_exception(exc) | ||
| consecutive_terminal_polls = 0 # Reset on error | ||
| else: | ||
| if self._handle_conversation_status(status): | ||
| logger.info( | ||
| "Run completed with status: %s (elapsed: %.1fs)", | ||
| status, | ||
| elapsed, | ||
| ) | ||
| return | ||
|
|
||
| time.sleep(poll_interval) | ||
| # Raises ConversationRunError for ERROR/STUCK states | ||
| self._handle_conversation_status(status) | ||
xingyaoww marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Track consecutive terminal polls as a fallback for WS issues. | ||
| # If WebSocket is delayed/disconnected, we return after multiple | ||
| # consecutive polls confirm the terminal status. | ||
| if status and ConversationExecutionStatus(status).is_terminal(): | ||
| consecutive_terminal_polls += 1 | ||
| if consecutive_terminal_polls >= TERMINAL_POLL_THRESHOLD: | ||
| logger.info( | ||
| "Run completed via REST fallback after %d consecutive " | ||
| "terminal polls (status: %s, elapsed: %.1fs). " | ||
| "Reconciling events...", | ||
| consecutive_terminal_polls, | ||
| status, | ||
| elapsed, | ||
| ) | ||
| # Reconcile events to catch any that were missed via WS. | ||
| # This is only called in the fallback path, so it doesn't | ||
| # add overhead in the common case where WS works. | ||
| self._state.events.reconcile() | ||
| return | ||
| else: | ||
| consecutive_terminal_polls = 0 | ||
|
|
||
| def _poll_status_once(self) -> str | None: | ||
| """Fetch the current execution status from the remote conversation.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟠 Important: This fix relies on a critical assumption that WebSocket events are processed sequentially in the callback.
Why this matters:
ConversationStateUpdateEvent(finished)is processed and put in the queue BEFOREActionEvent(finish)is fully processed and added to state, the race condition still existsSuggestion:
Add a comment documenting this critical assumption:
Consider adding a test that validates sequential execution order.