diff --git a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py index 515d58368e..7fc3a5fa8c 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py @@ -737,7 +737,9 @@ def _wait_for_run_completion( timeout: Maximum time in seconds to wait. Raises: - ConversationRunError: If the wait times out. + ConversationRunError: If the run fails, the conversation disappears, + or the wait times out. Transient network errors, 429s, and 5xx + responses are retried until timeout. """ start_time = time.monotonic() @@ -753,41 +755,77 @@ def _wait_for_run_completion( ) try: - resp = _send_request( - self._client, - "GET", - f"/api/conversations/{self._id}", - timeout=30, - ) - info = resp.json() - status = info.get("execution_status") - - if status != ConversationExecutionStatus.RUNNING.value: - if status == ConversationExecutionStatus.ERROR.value: - detail = self._get_last_error_detail() - raise ConversationRunError( - self._id, - RuntimeError( - detail or "Remote conversation ended with error" - ), - ) - if status == ConversationExecutionStatus.STUCK.value: - raise ConversationRunError( - self._id, - RuntimeError("Remote conversation got stuck"), - ) + status = self._poll_status_once() + except Exception as exc: + self._handle_poll_exception(exc) + else: + if self._handle_conversation_status(status): logger.info( - f"Run completed with status: {status} (elapsed: {elapsed:.1f}s)" + "Run completed with status: %s (elapsed: %.1fs)", + status, + elapsed, ) return - except Exception as e: - # Log but continue polling - transient network errors shouldn't - # stop us from waiting for the run to complete - logger.warning(f"Error polling status (will retry): {e}") - time.sleep(poll_interval) + def _poll_status_once(self) -> str | None: + """Fetch the current execution status from the remote conversation.""" + resp = _send_request( + self._client, + "GET", + f"/api/conversations/{self._id}", + timeout=30, + ) + info = resp.json() + return info.get("execution_status") + + def _handle_conversation_status(self, status: str | None) -> bool: + """Handle non-running statuses; return True if the run is complete.""" + if status == ConversationExecutionStatus.RUNNING.value: + return False + if status == ConversationExecutionStatus.ERROR.value: + detail = self._get_last_error_detail() + raise ConversationRunError( + self._id, + RuntimeError(detail or "Remote conversation ended with error"), + ) + if status == ConversationExecutionStatus.STUCK.value: + raise ConversationRunError( + self._id, + RuntimeError("Remote conversation got stuck"), + ) + return True + + def _handle_poll_exception(self, exc: Exception) -> None: + """Classify polling exceptions into retryable vs terminal failures.""" + if isinstance(exc, httpx.HTTPStatusError): + status_code = exc.response.status_code + reason = exc.response.reason_phrase + if status_code == 404: + raise ConversationRunError( + self._id, + RuntimeError( + "Remote conversation not found (404). " + "The runtime may have been deleted." + ), + ) from exc + if 400 <= status_code < 500 and status_code != 429: + raise ConversationRunError( + self._id, + RuntimeError(f"Polling failed with HTTP {status_code} {reason}"), + ) from exc + logger.warning( + "Error polling status (will retry): HTTP %d %s", + status_code, + reason, + ) + return + if isinstance(exc, httpx.RequestError): + logger.warning(f"Error polling status (will retry): {exc}") + return + raise ConversationRunError(self._id, exc) from exc + def _get_last_error_detail(self) -> str | None: """Return the most recent ConversationErrorEvent detail, if available.""" events = self._state.events diff --git a/tests/sdk/conversation/remote/test_remote_conversation.py b/tests/sdk/conversation/remote/test_remote_conversation.py index 40c1c70962..56a81dc85b 100644 --- a/tests/sdk/conversation/remote/test_remote_conversation.py +++ b/tests/sdk/conversation/remote/test_remote_conversation.py @@ -8,6 +8,7 @@ from pydantic import SecretStr from openhands.sdk.agent import Agent +from openhands.sdk.conversation.exceptions import ConversationRunError from openhands.sdk.conversation.impl.remote_conversation import RemoteConversation from openhands.sdk.conversation.secret_registry import SecretValue from openhands.sdk.conversation.visualizer import DefaultConversationVisualizer @@ -461,6 +462,94 @@ def custom_side_effect(method, url, **kwargs): f"Should have polled 3 times (2 running + 1 finished), got {poll_count[0]}" ) + @patch( + "openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient" + ) + def test_remote_conversation_run_error_status_raises(self, mock_ws_client): + """Test that error status raises ConversationRunError.""" + conversation_id = str(uuid.uuid4()) + mock_client_instance = self.setup_mock_client(conversation_id=conversation_id) + + original_side_effect = mock_client_instance.request.side_effect + + def custom_side_effect(method, url, **kwargs): + if method == "GET" and url == f"/api/conversations/{conversation_id}": + response = Mock() + response.raise_for_status.return_value = None + response.json.return_value = { + "id": conversation_id, + "execution_status": "error", + } + return response + return original_side_effect(method, url, **kwargs) + + mock_client_instance.request.side_effect = custom_side_effect + + mock_ws_instance = Mock() + mock_ws_client.return_value = mock_ws_instance + + conversation = RemoteConversation(agent=self.agent, workspace=self.workspace) + with pytest.raises(ConversationRunError) as exc_info: + conversation.run(poll_interval=0.01) + assert "error" in str(exc_info.value).lower() + + @patch( + "openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient" + ) + def test_remote_conversation_run_stuck_status_raises(self, mock_ws_client): + """Test that stuck status raises ConversationRunError.""" + conversation_id = str(uuid.uuid4()) + mock_client_instance = self.setup_mock_client(conversation_id=conversation_id) + + original_side_effect = mock_client_instance.request.side_effect + + def custom_side_effect(method, url, **kwargs): + if method == "GET" and url == f"/api/conversations/{conversation_id}": + response = Mock() + response.raise_for_status.return_value = None + response.json.return_value = { + "id": conversation_id, + "execution_status": "stuck", + } + return response + return original_side_effect(method, url, **kwargs) + + mock_client_instance.request.side_effect = custom_side_effect + + mock_ws_instance = Mock() + mock_ws_client.return_value = mock_ws_instance + + conversation = RemoteConversation(agent=self.agent, workspace=self.workspace) + with pytest.raises(ConversationRunError) as exc_info: + conversation.run(poll_interval=0.01) + assert "stuck" in str(exc_info.value).lower() + + @patch( + "openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient" + ) + def test_remote_conversation_run_404_raises(self, mock_ws_client): + """Test that 404s during polling raise ConversationRunError.""" + conversation_id = str(uuid.uuid4()) + mock_client_instance = self.setup_mock_client(conversation_id=conversation_id) + + original_side_effect = mock_client_instance.request.side_effect + + def custom_side_effect(method, url, **kwargs): + if method == "GET" and url == f"/api/conversations/{conversation_id}": + request = httpx.Request("GET", f"http://localhost{url}") + return httpx.Response(404, request=request, text="Not Found") + return original_side_effect(method, url, **kwargs) + + mock_client_instance.request.side_effect = custom_side_effect + + mock_ws_instance = Mock() + mock_ws_client.return_value = mock_ws_instance + + conversation = RemoteConversation(agent=self.agent, workspace=self.workspace) + with pytest.raises(ConversationRunError) as exc_info: + conversation.run(poll_interval=0.01) + assert "not found" in str(exc_info.value).lower() + @patch( "openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient" )