16 | 16 | from http import HTTPStatus |
17 | 17 | from typing import Any, Dict, List, Union, Tuple |
18 | 18 | from pathlib import Path |
| 19 | +from unittest.mock import patch |
19 | 20 | import ray |
20 | 21 | import threading |
21 | 22 | import requests |
25 | 26 | from litellm import completion as litellm_completion |
26 | 27 | from litellm import acompletion as litellm_async_completion |
27 | 28 | from litellm import atext_completion as litellm_async_text_completion |
| 29 | +import logging |
28 | 30 |
29 | 31 | from skyrl_train.config import SkyRLConfig |
30 | 32 | from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient |
31 | 33 | from skyrl_train.inference_engines.base import ConversationType |
32 | 34 | from tests.gpu.utils import init_worker_with_type, get_test_prompts |
33 | 35 | from skyrl_train.inference_engines.utils import get_sampling_params_for_backend |
| 36 | +import skyrl_train.inference_engines.inference_engine_client_http_endpoint as http_endpoint_module |
34 | 37 | from skyrl_train.inference_engines.inference_engine_client_http_endpoint import ( |
35 | 38 | serve, |
36 | 39 | wait_for_server_ready, |
@@ -595,9 +598,12 @@ class TestSchema(BaseModel): |
595 | 598 |
596 | 599 | # TODO(Charlie): sglang has slightly different error response format. We need to handle it. |
597 | 600 | @pytest.mark.vllm |
598 | | -def test_http_endpoint_error_handling(ray_init_fixture): |
| 601 | +def test_http_endpoint_error_handling(ray_init_fixture, caplog): |
599 | 602 | """ |
600 | | - Test error handling for various invalid requests. |
| 603 | + Test error handling for various invalid requests and internal server errors. |
| 604 | +
| 605 | + Tests validation errors (400) for invalid requests and verifies that internal |
| 606 | + server errors (500) are logged with traceback server-side (not exposed to client). |
601 | 607 | """ |
602 | 608 | try: |
603 | 609 | cfg = get_test_actor_config(num_inference_engines=2, model=MODEL_QWEN2_5) |
@@ -718,6 +724,37 @@ def test_http_endpoint_error_handling(ray_init_fixture): |
718 | 724 | r = requests.post(f"{base_url}/v1/completions", json=bad_payload) |
719 | 725 | assert r.status_code == HTTPStatus.BAD_REQUEST |
720 | 726 |
| 727 | + # Test internal server errors (500) return proper error responses |
| 728 | + # Traceback is logged server-side only (not exposed to client per CWE-209) |
| 729 | + caplog.set_level(logging.ERROR) |
| 730 | + original_client = http_endpoint_module._global_inference_engine_client |
| 731 | + |
| 732 | + internal_error_cases = [ |
| 733 | + ( |
| 734 | + "chat_completion", |
| 735 | + "/v1/chat/completions", |
| 736 | + {"messages": [{"role": "user", "content": "Hello"}]}, |
| 737 | + KeyError("choices"), |
| 738 | + ), |
| 739 | + ("completion", "/v1/completions", {"prompt": "Hello"}, RuntimeError("Simulated internal error")), |
| 740 | + ] |
| 741 | + for method_name, endpoint, extra_payload, exception in internal_error_cases: |
| 742 | + |
| 743 | + async def mock_raises(*args, exc=exception, **kwargs): |
| 744 | + raise exc |
| 745 | + |
| 746 | + caplog.clear() |
| 747 | + with patch.object(original_client, method_name, side_effect=mock_raises): |
| 748 | + response = requests.post(f"{base_url}{endpoint}", json={"model": MODEL_QWEN2_5, **extra_payload}) |
| 749 | + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR |
| 750 | + error_data = response.json() |
| 751 | + error_message = error_data["error"]["message"] |
| 752 | + assert str(exception) in error_message or type(exception).__name__ in error_message |
| 753 | + assert "Traceback" not in error_message # Not exposed to client (CWE-209) |
| 754 | + assert error_data["error"]["code"] == 500 |
| 755 | + assert "Traceback (most recent call last):" in caplog.text # Logged server-side |
| 756 | + assert type(exception).__name__ in caplog.text |
| 757 | + |
721 | 758 | finally: |
722 | 759 | shutdown_server(host=SERVER_HOST, port=server_port, max_wait_seconds=5) |
723 | 760 | if server_thread.is_alive(): |
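For reference, the behavior the new assertions pin down — a sanitized `{"error": {"message": ..., "code": 500}}` body for the client plus the full traceback in the server log (CWE-209) — can be approximated with a catch-all FastAPI exception handler along the lines of the sketch below. This is an illustrative sketch only, not the actual `inference_engine_client_http_endpoint` implementation; the app, handler, and logger names are assumptions.

```python
# Illustrative sketch only -- not the skyrl_train endpoint code. Names here are assumed.
import logging
from http import HTTPStatus

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)
app = FastAPI()


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    # Log the full traceback server-side only; the client response never includes it (CWE-209).
    logger.error("Unhandled error while serving %s", request.url.path, exc_info=exc)
    return JSONResponse(
        status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
        content={
            "error": {
                # Short, sanitized message: exception type and str(exc), no stack trace.
                "message": f"{type(exc).__name__}: {exc}",
                "code": int(HTTPStatus.INTERNAL_SERVER_ERROR),
            }
        },
    )
```

In a handler like this, passing `exc_info=exc` to the logger is what puts `"Traceback (most recent call last):"` into the captured log text, while the JSON payload returned to the client stays free of stack traces.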