@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Convert FastAPI validation errors to OpenAI-compatible format.

    FastAPI returns 422 with a 'detail' array by default. OpenAI API uses
    400 with an 'error' object containing message, type, and param fields.

    Only the first validation error is reported; any further errors in the
    same request are not included in the response.

    Args:
        request: The incoming request. Not used by this handler.
        exc: The validation error raised while parsing the request.

    Returns:
        The 400 response produced by ``create_openai_error_response`` with
        ``error_type="invalid_request_error"``.
    """
    errors = exc.errors()
    if errors:
        first_error = errors[0]
        location = first_error["loc"]
        # Report the last element of the location tuple,
        # e.g. ('body', 'n') -> 'n'.
        param = location[-1] if location else None
        message = first_error["msg"]
    else:
        param = None
        message = "Invalid request parameters"

    return create_openai_error_response(
        status_code=400,
        message=message,
        error_type="invalid_request_error",
        # Compare against None instead of relying on truthiness: a location
        # may end in a list index, and index 0 is a real parameter position
        # that must not be collapsed to "no parameter".
        param=str(param) if param is not None else None,
    )


def _build_model_options(request: ChatCompletionRequest) -> dict:
    """Build model_options dict from OpenAI-compatible request parameters.

    The request is dumped with ``exclude_none=True``, so fields whose value is
    None never appear in the result. Fields named in ``excluded_fields`` are
    dropped, and the remaining keys are handed to ``ModelOption.replace_keys``
    together with the OpenAI-name -> ``ModelOption`` mapping.

    Args:
        request: The parsed chat completion request.

    Returns:
        The dict returned by ``ModelOption.replace_keys`` for the filtered
        request fields.
    """
    excluded_fields = {
        # Request structure fields (handled separately)
        "messages",  # Chat messages - passed separately to serve()
        "requirements",  # Mellea requirements - passed separately to serve()
        # Routing/metadata fields (not generation parameters)
        "model",  # Model identifier - used for routing, not generation
        "n",  # Number of completions - not supported in Mellea's model_options
        "user",  # User tracking ID - metadata, not a generation parameter
        "extra",  # Pydantic's extra fields dict - unused (see model_config)
        # Not-yet-implemented OpenAI parameters (silently ignored)
        "stream",  # Streaming responses - not yet implemented
        "stop",  # Stop sequences - not yet implemented
        "top_p",  # Nucleus sampling - not yet implemented
        "presence_penalty",  # Presence penalty - not yet implemented
        "frequency_penalty",  # Frequency penalty - not yet implemented
        "logit_bias",  # Logit bias - not yet implemented
        "response_format",  # Response format (json_object) - not yet implemented
        "functions",  # Legacy function calling - not yet implemented
        "function_call",  # Legacy function calling - not yet implemented
        "tools",  # Tool calling - not yet implemented
        "tool_choice",  # Tool choice - not yet implemented
    }
    # OpenAI request field name -> ModelOption key it is renamed to.
    openai_to_model_option = {
        "temperature": ModelOption.TEMPERATURE,
        "max_tokens": ModelOption.MAX_NEW_TOKENS,
        "seed": ModelOption.SEED,
    }

    filtered_options = {
        key: value
        for key, value in request.model_dump(exclude_none=True).items()
        if key not in excluded_fields
    }
    # NOTE(review): keys that are neither excluded nor in the mapping are
    # presumably passed through unchanged - confirm against replace_keys.
    return ModelOption.replace_keys(filtered_options, openai_to_model_option)
class TestBuildModelOptions:
    """Exercise ``_build_model_options`` directly, without the HTTP endpoint."""

    @staticmethod
    def _options_for(**fields):
        """Build a one-message request with ``fields`` and return its options."""
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="test")],
            **fields,
        )
        return _build_model_options(chat_request)

    def test_temperature_mapping(self):
        """``temperature`` ends up under ``ModelOption.TEMPERATURE``."""
        built = self._options_for(temperature=0.7)
        assert built[ModelOption.TEMPERATURE] == 0.7

    def test_max_tokens_mapping(self):
        """``max_tokens`` ends up under ``ModelOption.MAX_NEW_TOKENS``."""
        built = self._options_for(max_tokens=100)
        assert built[ModelOption.MAX_NEW_TOKENS] == 100

    def test_seed_mapping(self):
        """``seed`` ends up under ``ModelOption.SEED``."""
        built = self._options_for(seed=42)
        assert built[ModelOption.SEED] == 42

    def test_multiple_options(self):
        """Several supported options are mapped in the same call."""
        built = self._options_for(temperature=0.8, max_tokens=200, seed=123)
        assert built[ModelOption.TEMPERATURE] == 0.8
        assert built[ModelOption.MAX_NEW_TOKENS] == 200
        assert built[ModelOption.SEED] == 123

    def test_excluded_fields_not_in_output(self):
        """Routing/metadata fields are absent while temperature survives."""
        built = self._options_for(
            n=1, user="test-user", stream=False, temperature=0.5
        )
        # None of the excluded fields may leak into the options.
        assert "model" not in built
        assert "messages" not in built
        assert "n" not in built
        assert "user" not in built
        assert "stream" not in built
        # The supported generation parameter is still there.
        assert ModelOption.TEMPERATURE in built

    def test_none_values_excluded(self):
        """Options explicitly set to None do not appear in the output."""
        built = self._options_for(temperature=None, max_tokens=None)
        assert ModelOption.TEMPERATURE not in built
        assert ModelOption.MAX_NEW_TOKENS not in built

    def test_minimal_request_includes_defaults(self):
        """A request with only model and messages yields just the defaults."""
        built = self._options_for()
        # ChatCompletionRequest has default temperature=1.0
        assert built == {ModelOption.TEMPERATURE: 1.0}

    def test_requirements_excluded(self):
        """``requirements`` is never forwarded as a model option."""
        built = self._options_for(requirements=["req1", "req2"], temperature=0.7)
        assert "requirements" not in built
        assert ModelOption.TEMPERATURE in built
"messages" not in model_options assert "requirements" not in model_options @@ -223,3 +229,309 @@ async def test_all_fields_together(self, mock_module, sample_request): assert response.system_fingerprint is None # Not tracking backend config assert response.object == "chat.completion" assert response.id.startswith("chatcmpl-") + + @pytest.mark.asyncio + async def test_n_greater_than_1_rejected(self, mock_module): + """Test that requests with n > 1 are rejected with appropriate error.""" + import json + + from fastapi.responses import JSONResponse + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + n=2, # Request multiple completions + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should return a JSONResponse error, not a ChatCompletion + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + # Decode the response body + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert error_data["error"]["param"] == "n" + assert "not supported" in error_data["error"]["message"].lower() + + # Verify serve was never called + mock_module.serve.assert_not_called() + + @pytest.mark.asyncio + async def test_n_equals_1_accepted(self, mock_module): + """Test that requests with n=1 are accepted.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + n=1, # Explicitly set to 1 + ) + + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == "Test response" 
+ + # Verify serve was called + mock_module.serve.assert_called_once() + + @pytest.mark.asyncio + async def test_n_less_than_1_rejected_by_pydantic(self, mock_module): + """Test that requests with n < 1 are rejected by Pydantic validation. + + FastAPI automatically validates request models before they reach the endpoint, + so n=0 or negative values will be caught by the framework, not our code. + This test documents that behavior. + """ + from pydantic import ValidationError + + # Pydantic validation happens before the endpoint is called + with pytest.raises(ValidationError) as exc_info: + ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + n=0, # Invalid: less than 1 + ) + + # Verify the error is about the 'n' field + errors = exc_info.value.errors() + assert len(errors) == 1 + assert errors[0]["loc"] == ("n",) + assert errors[0]["type"] == "greater_than_equal" + + +class TestHTTPValidation: + """Tests for HTTP-level validation via FastAPI TestClient.""" + + def test_n_zero_rejected_at_http_level(self, mock_module): + """Test that n=0 is rejected with OpenAI-compatible error format. + + Pydantic validation errors are caught by our custom exception handler + and converted to OpenAI-compatible 400 errors (not FastAPI's default 422). 
+ """ + # Setup a test app with the exception handler + from cli.serve.app import validation_exception_handler + + app = FastAPI() + app.add_exception_handler(RequestValidationError, validation_exception_handler) + app.add_api_route( + "/v1/chat/completions", make_chat_endpoint(mock_module), methods=["POST"] + ) + client = TestClient(app) + + # Send request with n=0 + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "n": 0, + }, + ) + + # Our exception handler converts to OpenAI-compatible 400 error + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert error_data["error"]["param"] == "n" + # Pydantic's error message is used as-is + assert "greater than or equal to 1" in error_data["error"]["message"].lower() + + # Verify serve was never called + mock_module.serve.assert_not_called() + + def test_n_two_rejected_at_endpoint_level(self, mock_module): + """Test that n=2 is rejected by our endpoint logic (not Pydantic). + + While n=2 passes Pydantic validation (ge=1), our endpoint explicitly + rejects it because we don't support multiple completions. 
+ """ + # Setup a test app + app = FastAPI() + app.add_api_route( + "/v1/chat/completions", make_chat_endpoint(mock_module), methods=["POST"] + ) + client = TestClient(app) + + # Send request with n=2 + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "n": 2, + }, + ) + + # Our endpoint returns 400 for unsupported n > 1 + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert error_data["error"]["param"] == "n" + assert "not supported" in error_data["error"]["message"].lower() + + # Verify serve was never called + mock_module.serve.assert_not_called() + + @pytest.mark.asyncio + async def test_n_none_accepted(self, mock_module): + """Test that requests with n=None (default) are accepted.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + # n not specified, defaults to 1 + ) + + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == "Test response" + + # Verify serve was called + mock_module.serve.assert_called_once() + + @pytest.mark.asyncio + async def test_unsupported_params_excluded_from_model_options(self, mock_module): + """Test that unsupported OpenAI parameters are excluded from model_options.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + temperature=0.7, + max_tokens=100, + # Unsupported parameters that should be excluded + stream=False, + stop=["END"], + top_p=0.9, + presence_penalty=0.5, + frequency_penalty=0.3, + logit_bias={"123": 1.0}, + ) + + mock_output = ModelOutputThunk("Test 
response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed + assert isinstance(response, ChatCompletion) + + # Verify serve was called with correct model_options + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # Supported parameters should be present + from mellea.backends.model_options import ModelOption + + assert ModelOption.TEMPERATURE in model_options + assert model_options[ModelOption.TEMPERATURE] == 0.7 + assert ModelOption.MAX_NEW_TOKENS in model_options + assert model_options[ModelOption.MAX_NEW_TOKENS] == 100 + + # Unsupported parameters should NOT be in model_options + assert "stream" not in model_options + assert "stop" not in model_options + assert "top_p" not in model_options + assert "presence_penalty" not in model_options + assert "frequency_penalty" not in model_options + assert "logit_bias" not in model_options + + @pytest.mark.asyncio + async def test_tool_params_excluded_from_model_options(self, mock_module): + """Test that tool-related parameters are excluded from model_options.""" + from cli.serve.models import ( + FunctionDefinition, + FunctionParameters, + ToolFunction, + ) + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + # Tool-related parameters that should be excluded + tools=[ + ToolFunction( + type="function", + function=FunctionDefinition( + name="test_func", + description="A test function", + parameters=FunctionParameters(RootModel={"type": "object"}), + ), + ) + ], + tool_choice="auto", + functions=[ + FunctionDefinition( + name="legacy_func", + description="A legacy function", + parameters=FunctionParameters(RootModel={"type": "object"}), + ) + ], + function_call="auto", + ) + + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint 
= make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed + assert isinstance(response, ChatCompletion) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # Tool-related parameters should NOT be in model_options + assert "tools" not in model_options + assert "tool_choice" not in model_options + assert "functions" not in model_options + assert "function_call" not in model_options + + @pytest.mark.asyncio + async def test_response_format_excluded_from_model_options(self, mock_module): + """Test that response_format parameter is excluded from model_options.""" + from cli.serve.models import ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + response_format=ResponseFormat(type="json_object"), + ) + + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed + assert isinstance(response, ChatCompletion) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # response_format should NOT be in model_options + assert "response_format" not in model_options diff --git a/test/cli/test_serve_sync_async.py b/test/cli/test_serve_sync_async.py new file mode 100644 index 000000000..8e0dab9f8 --- /dev/null +++ b/test/cli/test_serve_sync_async.py @@ -0,0 +1,255 @@ +"""Tests for sync/async serve function handling in m serve.""" + +import asyncio +from unittest.mock import AsyncMock, Mock + +import pytest + +from cli.serve.app import make_chat_endpoint +from cli.serve.models import ChatCompletionRequest, ChatMessage +from mellea.backends.model_options import ModelOption +from mellea.core import ModelOutputThunk + + 
@pytest.fixture
def mock_sync_module():
    """Create a mock module with a synchronous serve function."""
    module = Mock()
    module.__name__ = "test_sync_module"

    def sync_serve(input, requirements=None, model_options=None):
        """Synchronous serve function that echoes the last message."""
        return ModelOutputThunk(f"Sync response to: {input[-1].content}")

    # Wrap the function in a Mock so the tests can inspect the calls.
    module.serve = Mock(side_effect=sync_serve)
    return module


@pytest.fixture
def mock_async_module():
    """Create a mock module with an asynchronous serve function."""
    module = Mock()
    module.__name__ = "test_async_module"

    async def async_serve(input, requirements=None, model_options=None):
        """Asynchronous serve function that echoes the last message."""
        # Simulate async work
        await asyncio.sleep(0.01)
        return ModelOutputThunk(f"Async response to: {input[-1].content}")

    module.serve = AsyncMock(side_effect=async_serve)
    return module


@pytest.fixture
def mock_slow_sync_module():
    """Create a mock module with a slow synchronous serve function."""
    module = Mock()
    module.__name__ = "test_slow_sync_module"

    def slow_sync_serve(input, requirements=None, model_options=None):
        """Slow synchronous serve function that would block the event loop."""
        import time

        time.sleep(1)  # Simulate blocking work with a clear timing signal
        return ModelOutputThunk(f"Slow sync response to: {input[-1].content}")

    # Assigned as a plain function (not wrapped in Mock), so call tracking is
    # not available on this fixture.
    module.serve = slow_sync_serve
    return module


class TestSyncAsyncServeHandling:
    """Test that the endpoint handles both sync and async serve functions."""

    @pytest.mark.asyncio
    async def test_sync_serve_function(self, mock_sync_module):
        """Test that synchronous serve functions work correctly."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Hello sync!")],
        )

        response = await endpoint(request)

        assert response.choices[0].message.content == "Sync response to: Hello sync!"
        assert response.model == "test-model"
        assert response.object == "chat.completion"

    @pytest.mark.asyncio
    async def test_async_serve_function(self, mock_async_module):
        """Test that asynchronous serve functions work correctly."""
        endpoint = make_chat_endpoint(mock_async_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Hello async!")],
        )

        response = await endpoint(request)

        assert response.choices[0].message.content == "Async response to: Hello async!"
        assert response.model == "test-model"
        assert response.object == "chat.completion"

    @pytest.mark.asyncio
    async def test_slow_sync_does_not_block(self, mock_slow_sync_module):
        """Test that slow sync serve calls do not block the event loop.

        Two requests are issued concurrently against a serve function that
        sleeps for 1s. If the sync call blocked the event loop, the two calls
        would run back to back and take about 2s in total; if they overlap,
        the pair finishes in roughly 1s.
        """
        import time

        endpoint = make_chat_endpoint(mock_slow_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Hello slow!")],
        )

        # Use a monotonic, high-resolution clock for the interval: wall-clock
        # time can be adjusted while the test runs and skew the measurement.
        start = time.perf_counter()
        results = await asyncio.gather(endpoint(request), endpoint(request))
        elapsed = time.perf_counter() - start

        # If blocking: would take ~2s (1s + 1s back to back)
        # If non-blocking: should take ~1s (both calls overlap)
        # Allow some overhead, but should still be well below the blocking case.
        assert elapsed < 2, (
            f"Took {elapsed:.3f}s - appears to be blocking (expected ~1s)"
        )
        assert all(
            r.choices[0].message.content == "Slow sync response to: Hello slow!"
            for r in results
        )

    @pytest.mark.asyncio
    async def test_concurrent_requests_with_sync_serve(self, mock_slow_sync_module):
        """Test that multiple sync requests can be handled concurrently."""
        endpoint = make_chat_endpoint(mock_slow_sync_module)

        requests = [
            ChatCompletionRequest(
                model="test-model",
                messages=[ChatMessage(role="user", content=f"Request {i}")],
            )
            for i in range(3)
        ]

        # Run requests concurrently; gather preserves the input order.
        responses = await asyncio.gather(*[endpoint(req) for req in requests])

        # All should complete successfully, each with its own message echoed.
        assert len(responses) == 3
        for i, response in enumerate(responses):
            assert (
                response.choices[0].message.content
                == f"Slow sync response to: Request {i}"
            )

    @pytest.mark.asyncio
    async def test_requirements_passed_to_serve(self, mock_sync_module):
        """Test that requirements are correctly passed to serve function."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Test")],
            requirements=["req1", "req2"],
        )

        await endpoint(request)

        # Verify serve was called with requirements
        mock_sync_module.serve.assert_called_once()
        call_kwargs = mock_sync_module.serve.call_args.kwargs
        assert call_kwargs["requirements"] == ["req1", "req2"]

    @pytest.mark.asyncio
    async def test_model_options_passed_to_serve(self, mock_sync_module):
        """Test that model options are correctly passed to serve function."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Test")],
            temperature=0.7,
            max_tokens=100,
        )

        await endpoint(request)

        # Verify serve was called with model_options
        mock_sync_module.serve.assert_called_once()
        call_kwargs = mock_sync_module.serve.call_args.kwargs
        model_options = call_kwargs["model_options"]
        assert ModelOption.TEMPERATURE in model_options
        assert ModelOption.MAX_NEW_TOKENS in model_options

    @pytest.mark.asyncio
    async def test_openai_params_mapped_to_model_options(self, mock_sync_module):
        """Test that OpenAI parameters are mapped to ModelOption keys."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Test")],
            temperature=0.8,
            max_tokens=150,
            seed=42,
        )

        await endpoint(request)

        # Verify parameters are mapped correctly
        mock_sync_module.serve.assert_called_once()
        call_kwargs = mock_sync_module.serve.call_args.kwargs
        model_options = call_kwargs["model_options"]

        assert model_options[ModelOption.TEMPERATURE] == 0.8
        assert model_options[ModelOption.MAX_NEW_TOKENS] == 150
        assert model_options[ModelOption.SEED] == 42


class TestEndpointIntegration:
    """Integration tests for the full endpoint."""

    def test_endpoint_name_set_correctly(self, mock_sync_module):
        """Test that endpoint function name is set correctly."""
        endpoint = make_chat_endpoint(mock_sync_module)
        # NOTE(review): the name presumably embeds the module's __name__
        # ("test_sync_module") - confirm against make_chat_endpoint.
        assert endpoint.__name__ == "chat_test_sync_module_endpoint"

    @pytest.mark.asyncio
    async def test_completion_id_generated(self, mock_sync_module):
        """Test that each response gets a unique completion ID."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model", messages=[ChatMessage(role="user", content="Test")]
        )

        response1 = await endpoint(request)
        response2 = await endpoint(request)

        assert response1.id.startswith("chatcmpl-")
        assert response2.id.startswith("chatcmpl-")
        assert response1.id != response2.id

    @pytest.mark.asyncio
    async def test_timestamp_generated(self, mock_sync_module):
        """Test that response includes a timestamp."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model", messages=[ChatMessage(role="user", content="Test")]
        )

        response = await endpoint(request)

        assert isinstance(response.created, int)
        assert response.created > 0