@@ -3,7 +3,7 @@
 import sys
 from pathlib import Path
 from typing import Any
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from pydantic import BaseModel, field_validator, model_validator
@@ -15,6 +15,8 @@
 from resolve_model_config import (  # noqa: E402 # type: ignore[import-not-found]
     MODELS,
     find_models_by_id,
+    run_preflight_check,
+    test_model,
 )
 
 
@@ -254,3 +256,195 @@ def test_glm_5_config():
     assert model["display_name"] == "GLM-5"
     assert model["llm_config"]["model"] == "litellm_proxy/openrouter/z-ai/glm-5"
     assert model["llm_config"]["disable_vision"] is True
+
+
+# Tests for preflight check functionality
+
+
+class TestTestModel:
+    """Tests for the test_model function."""
+
+    def test_successful_response(self):
+        """Test that a successful model response returns True."""
+        model_config = {
+            "display_name": "Test Model",
+            "llm_config": {"model": "litellm_proxy/test-model"},
+        }
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="OK"))]
+
+        with patch(
+            "resolve_model_config.litellm.completion", return_value=mock_response
+        ):
+            success, message = test_model(model_config, "test-key", "https://test.com")
+
+        assert success is True
+        assert "✓" in message
+        assert "Test Model" in message
+
+    def test_empty_response(self):
+        """Test that an empty response returns False."""
+        model_config = {
+            "display_name": "Test Model",
+            "llm_config": {"model": "litellm_proxy/test-model"},
+        }
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content=""))]
+
+        with patch(
+            "resolve_model_config.litellm.completion", return_value=mock_response
+        ):
+            success, message = test_model(model_config, "test-key", "https://test.com")
+
+        assert success is False
+        assert "✗" in message
+        assert "Empty response" in message
+
+    def test_timeout_error(self):
+        """Test that timeout errors are handled correctly."""
+        import litellm
+
+        model_config = {
+            "display_name": "Test Model",
+            "llm_config": {"model": "litellm_proxy/test-model"},
+        }
+
+        with patch(
+            "resolve_model_config.litellm.completion",
+            side_effect=litellm.exceptions.Timeout(
+                message="Timeout", model="test-model", llm_provider="test"
+            ),
+        ):
+            success, message = test_model(model_config, "test-key", "https://test.com")
+
+        assert success is False
+        assert "✗" in message
+        assert "timed out" in message
+
+    def test_connection_error(self):
+        """Test that connection errors are handled correctly."""
+        import litellm
+
+        model_config = {
+            "display_name": "Test Model",
+            "llm_config": {"model": "litellm_proxy/test-model"},
+        }
+
+        with patch(
+            "resolve_model_config.litellm.completion",
+            side_effect=litellm.exceptions.APIConnectionError(
+                message="Connection failed", llm_provider="test", model="test-model"
+            ),
+        ):
+            success, message = test_model(model_config, "test-key", "https://test.com")
+
+        assert success is False
+        assert "✗" in message
+        assert "Connection error" in message
+
+    def test_model_not_found_error(self):
+        """Test that model not found errors are handled correctly."""
+        import litellm
+
+        model_config = {
+            "display_name": "Test Model",
+            "llm_config": {"model": "litellm_proxy/test-model"},
+        }
+
+        with patch(
+            "resolve_model_config.litellm.completion",
+            side_effect=litellm.exceptions.NotFoundError(
+                "Model not found", llm_provider="test", model="test-model"
+            ),
+        ):
+            success, message = test_model(model_config, "test-key", "https://test.com")
+
+        assert success is False
+        assert "✗" in message
+        assert "not found" in message
+
+    def test_passes_llm_config_params(self):
+        """Test that llm_config parameters are passed to litellm."""
+        model_config = {
+            "display_name": "Test Model",
+            "llm_config": {
+                "model": "litellm_proxy/test-model",
+                "temperature": 0.5,
+                "top_p": 0.9,
+            },
+        }
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="OK"))]
+
+        with patch(
+            "resolve_model_config.litellm.completion", return_value=mock_response
+        ) as mock_completion:
+            test_model(model_config, "test-key", "https://test.com")
+
+        mock_completion.assert_called_once()
+        call_kwargs = mock_completion.call_args[1]
+        assert call_kwargs["temperature"] == 0.5
+        assert call_kwargs["top_p"] == 0.9
+
+
+class TestRunPreflightCheck:
+    """Tests for the run_preflight_check function."""
+
+    def test_skip_when_no_api_key(self):
+        """Test that preflight check is skipped when LLM_API_KEY is not set."""
+        models = [{"display_name": "Test", "llm_config": {"model": "test"}}]
+
+        with patch.dict("os.environ", {}, clear=True):
+            result = run_preflight_check(models)
+
+        assert result is True  # Skipped = success
+
+    def test_skip_when_skip_preflight_true(self):
+        """Test that preflight check is skipped when SKIP_PREFLIGHT=true."""
+        models = [{"display_name": "Test", "llm_config": {"model": "test"}}]
+
+        with patch.dict(
+            "os.environ", {"LLM_API_KEY": "test", "SKIP_PREFLIGHT": "true"}
+        ):
+            result = run_preflight_check(models)
+
+        assert result is True  # Skipped = success
+
+    def test_all_models_pass(self):
+        """Test that preflight check returns True when all models pass."""
+        models = [
+            {"display_name": "Model A", "llm_config": {"model": "model-a"}},
+            {"display_name": "Model B", "llm_config": {"model": "model-b"}},
+        ]
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="OK"))]
+
+        with patch.dict("os.environ", {"LLM_API_KEY": "test"}):
+            with patch(
+                "resolve_model_config.litellm.completion", return_value=mock_response
+            ):
+                result = run_preflight_check(models)
+
+        assert result is True
+
+    def test_any_model_fails(self):
+        """Test that preflight check returns False when any model fails."""
+        models = [
+            {"display_name": "Model A", "llm_config": {"model": "model-a"}},
+            {"display_name": "Model B", "llm_config": {"model": "model-b"}},
+        ]
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="OK"))]
+
+        def mock_completion(**kwargs):
+            if kwargs["model"] == "model-b":
+                raise Exception("Model B failed")
+            return mock_response
+
+        with patch.dict("os.environ", {"LLM_API_KEY": "test"}):
+            with patch(
+                "resolve_model_config.litellm.completion", side_effect=mock_completion
+            ):
+                result = run_preflight_check(models)
+
+        assert result is False
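Taken together, these tests pin down the contract of the two helpers imported from resolve_model_config.py: test_model(model_config, api_key, base_url) returns a (success, message) tuple, and run_preflight_check(models) returns True when every model answers or when the check is skipped. The implementation itself is not part of this diff, so the following is only a reading aid: a minimal sketch consistent with the assertions above. The prompt text, kwarg names for api_key/base_url, and any environment variable other than LLM_API_KEY and SKIP_PREFLIGHT are assumptions, not the author's code.

    # Hypothetical sketch of the interface exercised by the tests above.
    # Only what the tests assert (return shape, message fragments, passing
    # llm_config params such as temperature/top_p, env var names LLM_API_KEY
    # and SKIP_PREFLIGHT) is certain; everything else is an assumption.
    import os

    import litellm


    def test_model(model_config, api_key, base_url):
        """Return (success, message) for a single round trip to one model."""
        name = model_config["display_name"]
        llm_config = dict(model_config["llm_config"])
        model = llm_config.pop("model")
        try:
            response = litellm.completion(
                model=model,
                messages=[{"role": "user", "content": "Reply with OK."}],  # assumed prompt
                api_key=api_key,  # exact kwarg names are assumptions
                base_url=base_url,
                **llm_config,  # e.g. temperature, top_p (asserted by the tests)
            )
        except litellm.exceptions.Timeout:
            return False, f"✗ {name}: request timed out"
        except litellm.exceptions.NotFoundError:
            return False, f"✗ {name}: model not found"
        except litellm.exceptions.APIConnectionError:
            return False, f"✗ {name}: Connection error"
        except Exception as exc:  # the preflight tests raise a bare Exception
            return False, f"✗ {name}: {exc}"
        content = response.choices[0].message.content
        if not content:
            return False, f"✗ {name}: Empty response"
        return True, f"✓ {name}: ok"


    def run_preflight_check(models):
        """Return True when every model responds, or when the check is skipped."""
        api_key = os.environ.get("LLM_API_KEY")
        skip = os.environ.get("SKIP_PREFLIGHT", "").lower() == "true"
        if not api_key or skip:
            return True  # skipped counts as success
        base_url = os.environ.get("LLM_BASE_URL", "")  # assumed env var name
        results = [test_model(m, api_key, base_url) for m in models]
        return all(ok for ok, _ in results)

Note the exception ordering in the sketch: litellm's Timeout derives from the connection-error family, so it must be caught before APIConnectionError for the "timed out" branch asserted in test_timeout_error to be reachable.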