Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,9 @@ def _wait_for_run_completion(
timeout: Maximum time in seconds to wait.

Raises:
ConversationRunError: If the wait times out.
ConversationRunError: If the run fails, the conversation disappears,
or the wait times out. Transient network errors, 429s, and 5xx
responses are retried until timeout.
"""
start_time = time.monotonic()

Expand All @@ -753,41 +755,77 @@ def _wait_for_run_completion(
)

try:
resp = _send_request(
self._client,
"GET",
f"/api/conversations/{self._id}",
timeout=30,
)
info = resp.json()
status = info.get("execution_status")

if status != ConversationExecutionStatus.RUNNING.value:
if status == ConversationExecutionStatus.ERROR.value:
detail = self._get_last_error_detail()
raise ConversationRunError(
self._id,
RuntimeError(
detail or "Remote conversation ended with error"
),
)
if status == ConversationExecutionStatus.STUCK.value:
raise ConversationRunError(
self._id,
RuntimeError("Remote conversation got stuck"),
)
status = self._poll_status_once()
except Exception as exc:
self._handle_poll_exception(exc)
else:
if self._handle_conversation_status(status):
logger.info(
f"Run completed with status: {status} (elapsed: {elapsed:.1f}s)"
"Run completed with status: %s (elapsed: %.1fs)",
status,
elapsed,
)
return

except Exception as e:
# Log but continue polling - transient network errors shouldn't
# stop us from waiting for the run to complete
logger.warning(f"Error polling status (will retry): {e}")

time.sleep(poll_interval)

def _poll_status_once(self) -> str | None:
    """Issue a single status poll against the remote conversation.

    Sends one GET request for this conversation and reads the
    ``execution_status`` field from the decoded JSON body. Nothing is
    caught here: any exception from the request or from JSON decoding
    propagates to the caller.

    Returns:
        The ``execution_status`` value from the response, or ``None``
        when the key is absent.
    """
    conversation_url = f"/api/conversations/{self._id}"
    response = _send_request(self._client, "GET", conversation_url, timeout=30)
    payload = response.json()
    return payload.get("execution_status")

def _handle_conversation_status(self, status: str | None) -> bool:
    """Decide whether polling can stop for the given execution status.

    Args:
        status: The ``execution_status`` string reported by the remote
            conversation, or ``None`` if it was missing.

    Returns:
        ``False`` while the conversation is still running; ``True`` for
        any other status that is not an error or stuck state (this
        includes ``None``).

    Raises:
        ConversationRunError: If the status is the error or the stuck
            status.
    """
    known = ConversationExecutionStatus

    if status == known.RUNNING.value:
        return False

    if status == known.STUCK.value:
        stuck = RuntimeError("Remote conversation got stuck")
        raise ConversationRunError(self._id, stuck)

    if status == known.ERROR.value:
        # Prefer the detail of the last recorded error event; fall back
        # to a generic message when none is available.
        message = self._get_last_error_detail()
        if not message:
            message = "Remote conversation ended with error"
        raise ConversationRunError(self._id, RuntimeError(message))

    return True

def _handle_poll_exception(self, exc: Exception) -> None:
    """Classify a polling exception as retryable or terminal.

    Returns normally (after logging a warning) when the failure is
    considered transient, so the caller can poll again. Raises when the
    failure is terminal.

    Args:
        exc: The exception raised by a single status poll.

    Raises:
        ConversationRunError: On HTTP 404, on any other 4xx response
            except 429, and on any exception that is neither an
            ``httpx.HTTPStatusError`` nor an ``httpx.RequestError``.
    """
    if isinstance(exc, httpx.HTTPStatusError):
        status_code = exc.response.status_code
        reason = exc.response.reason_phrase
        if status_code == 404:
            raise ConversationRunError(
                self._id,
                RuntimeError(
                    "Remote conversation not found (404). "
                    "The runtime may have been deleted."
                ),
            ) from exc
        if 400 <= status_code < 500 and status_code != 429:
            # Client errors other than rate limiting will not fix
            # themselves on retry, so fail fast.
            raise ConversationRunError(
                self._id,
                RuntimeError(f"Polling failed with HTTP {status_code} {reason}"),
            ) from exc
        # Everything left (429 and 5xx in practice) is treated as
        # transient and retried; the polling loop is expected to bound
        # the retries with its own timeout.
        logger.warning(
            "Error polling status (will retry): HTTP %d %s",
            status_code,
            reason,
        )
        return
    if isinstance(exc, httpx.RequestError):
        # Transport-level failure (no HTTP response): retry. Lazy %-args
        # keep this consistent with the other log calls in this method.
        logger.warning("Error polling status (will retry): %s", exc)
        return
    # Anything else is unexpected: surface it as a run failure.
    raise ConversationRunError(self._id, exc) from exc

def _get_last_error_detail(self) -> str | None:
"""Return the most recent ConversationErrorEvent detail, if available."""
events = self._state.events
Expand Down
89 changes: 89 additions & 0 deletions tests/sdk/conversation/remote/test_remote_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import SecretStr

from openhands.sdk.agent import Agent
from openhands.sdk.conversation.exceptions import ConversationRunError
from openhands.sdk.conversation.impl.remote_conversation import RemoteConversation
from openhands.sdk.conversation.secret_registry import SecretValue
from openhands.sdk.conversation.visualizer import DefaultConversationVisualizer
Expand Down Expand Up @@ -461,6 +462,94 @@ def custom_side_effect(method, url, **kwargs):
f"Should have polled 3 times (2 running + 1 finished), got {poll_count[0]}"
)

@patch(
    "openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient"
)
def test_remote_conversation_run_error_status_raises(self, mock_ws_client):
    """An ``error`` execution status from polling raises ConversationRunError."""
    conversation_id = str(uuid.uuid4())
    client = self.setup_mock_client(conversation_id=conversation_id)
    default_request = client.request.side_effect
    status_url = f"/api/conversations/{conversation_id}"

    def fake_request(method, url, **kwargs):
        # Only the status poll is overridden; every other request goes to
        # the side effect that was already installed on the client.
        if method != "GET" or url != status_url:
            return default_request(method, url, **kwargs)
        polled = Mock()
        polled.raise_for_status.return_value = None
        polled.json.return_value = {
            "id": conversation_id,
            "execution_status": "error",
        }
        return polled

    client.request.side_effect = fake_request
    mock_ws_client.return_value = Mock()

    conversation = RemoteConversation(agent=self.agent, workspace=self.workspace)
    with pytest.raises(ConversationRunError) as exc_info:
        conversation.run(poll_interval=0.01)
    assert "error" in str(exc_info.value).lower()

@patch(
    "openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient"
)
def test_remote_conversation_run_stuck_status_raises(self, mock_ws_client):
    """A ``stuck`` execution status from polling raises ConversationRunError."""
    conversation_id = str(uuid.uuid4())
    client = self.setup_mock_client(conversation_id=conversation_id)
    default_request = client.request.side_effect
    status_url = f"/api/conversations/{conversation_id}"

    def fake_request(method, url, **kwargs):
        # Only the status poll is overridden; every other request goes to
        # the side effect that was already installed on the client.
        if method != "GET" or url != status_url:
            return default_request(method, url, **kwargs)
        polled = Mock()
        polled.raise_for_status.return_value = None
        polled.json.return_value = {
            "id": conversation_id,
            "execution_status": "stuck",
        }
        return polled

    client.request.side_effect = fake_request
    mock_ws_client.return_value = Mock()

    conversation = RemoteConversation(agent=self.agent, workspace=self.workspace)
    with pytest.raises(ConversationRunError) as exc_info:
        conversation.run(poll_interval=0.01)
    assert "stuck" in str(exc_info.value).lower()

@patch(
    "openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient"
)
def test_remote_conversation_run_404_raises(self, mock_ws_client):
    """A 404 response to the status poll raises ConversationRunError."""
    conversation_id = str(uuid.uuid4())
    client = self.setup_mock_client(conversation_id=conversation_id)
    default_request = client.request.side_effect
    status_url = f"/api/conversations/{conversation_id}"

    def fake_request(method, url, **kwargs):
        # Answer the status poll with a real httpx 404 response; every
        # other request goes to the previously installed side effect.
        if method != "GET" or url != status_url:
            return default_request(method, url, **kwargs)
        not_found_request = httpx.Request("GET", f"http://localhost{url}")
        return httpx.Response(404, request=not_found_request, text="Not Found")

    client.request.side_effect = fake_request
    mock_ws_client.return_value = Mock()

    conversation = RemoteConversation(agent=self.agent, workspace=self.workspace)
    with pytest.raises(ConversationRunError) as exc_info:
        conversation.run(poll_interval=0.01)
    assert "not found" in str(exc_info.value).lower()

@patch(
"openhands.sdk.conversation.impl.remote_conversation.WebSocketCallbackClient"
)
Expand Down
Loading