Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions cli/serve/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""A simple app that runs an OpenAI compatible server wrapped around a M program."""

import asyncio
import importlib.util
import inspect
import os
import sys
import time
Expand All @@ -11,6 +13,8 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from mellea.backends.model_options import ModelOption

from .models import (
ChatCompletion,
ChatCompletionMessage,
Expand Down Expand Up @@ -53,20 +57,54 @@ def create_openai_error_response(
def make_chat_endpoint(module):
"""Makes a chat endpoint using a custom module."""

def _build_model_options(request: ChatCompletionRequest) -> dict:
    """Translate an OpenAI-compatible request into a Mellea model_options dict.

    Strips request-structure and routing fields, then renames the remaining
    OpenAI parameter names to their ModelOption equivalents.
    """
    # Fields that must never be forwarded as generation options.
    skip = {
        "messages",      # chat history — passed to serve() separately
        "requirements",  # Mellea requirements — passed to serve() separately
        "model",         # model identifier used for routing, not generation
        "n",             # number of completions — unsupported in model_options
        "user",          # user-tracking metadata, not a generation parameter
        "extra",         # pydantic extra-fields bucket — unused (see model_config)
    }
    # OpenAI parameter name -> Mellea ModelOption key.
    rename = {
        "temperature": ModelOption.TEMPERATURE,
        "max_tokens": ModelOption.MAX_NEW_TOKENS,
        "seed": ModelOption.SEED,
    }

    options = {
        name: value
        for name, value in request.model_dump().items()
        if name not in skip
    }
    return ModelOption.replace_keys(options, rename)

async def endpoint(request: ChatCompletionRequest):
try:
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
created_timestamp = int(time.time())

output = module.serve(
input=request.messages,
requirements=request.requirements,
model_options={
k: v
for k, v in request.model_dump().items()
if k not in ["messages", "requirements"]
},
)
model_options = _build_model_options(request)

# Detect if serve is async or sync and handle accordingly
if inspect.iscoroutinefunction(module.serve):
# It's async, await it directly
output = await module.serve(
input=request.messages,
requirements=request.requirements,
model_options=model_options,
)
else:
# It's sync, run in thread pool to avoid blocking event loop
output = await asyncio.to_thread(
module.serve,
input=request.messages,
requirements=request.requirements,
model_options=model_options,
)

# Extract usage information from the ModelOutputThunk if available
usage = None
Expand Down
4 changes: 4 additions & 0 deletions mellea/backends/model_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]:
# This will usually be a @@@<>@@@ ModelOption.<> key.
new_key = from_to.get(old_key, None)
if new_key:
# Skip if old_key and new_key are the same (no-op replacement)
if old_key == new_key:
continue

if new_options.get(new_key, None) is not None:
# The key already has a value associated with it in the dict. Leave it be.
conflict_log.append(
Expand Down
13 changes: 8 additions & 5 deletions test/cli/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ async def test_system_fingerprint_always_none(self, mock_module, sample_request)
@pytest.mark.asyncio
async def test_model_options_passed_correctly(self, mock_module, sample_request):
"""Test that model options are passed to serve function correctly."""
from mellea.backends.model_options import ModelOption

mock_output = ModelOutputThunk("Test response")
mock_module.serve.return_value = mock_output

Expand All @@ -134,11 +136,12 @@ async def test_model_options_passed_correctly(self, mock_module, sample_request)
assert "model_options" in call_args.kwargs
model_options = call_args.kwargs["model_options"]

# Should include temperature and max_tokens but not messages/requirements
assert "temperature" in model_options
assert model_options["temperature"] == 0.7
assert "max_tokens" in model_options
assert model_options["max_tokens"] == 100
# Should include ModelOption keys for temperature and max_tokens
# Note: TEMPERATURE is just "temperature" (not a sentinel), so it stays as-is
assert ModelOption.TEMPERATURE in model_options
assert model_options[ModelOption.TEMPERATURE] == 0.7
assert ModelOption.MAX_NEW_TOKENS in model_options
assert model_options[ModelOption.MAX_NEW_TOKENS] == 100
assert "messages" not in model_options
assert "requirements" not in model_options

Expand Down
255 changes: 255 additions & 0 deletions test/cli/test_serve_sync_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
"""Tests for sync/async serve function handling in m serve."""

import asyncio
from unittest.mock import AsyncMock, Mock

import pytest

from cli.serve.app import make_chat_endpoint
from cli.serve.models import ChatCompletionRequest, ChatMessage
from mellea.backends.model_options import ModelOption
from mellea.core import ModelOutputThunk


@pytest.fixture
def mock_sync_module():
    """Provide a mock module whose serve function is plain (blocking) sync."""
    module = Mock()
    module.__name__ = "test_sync_module"

    def _serve(input, requirements=None, model_options=None):
        """Synchronous serve implementation echoing the last message."""
        return ModelOutputThunk(f"Sync response to: {input[-1].content}")

    # Wrap in Mock so tests can assert on call arguments afterwards.
    module.serve = Mock(side_effect=_serve)
    return module


@pytest.fixture
def mock_async_module():
    """Provide a mock module whose serve function is a coroutine function."""
    module = Mock()
    module.__name__ = "test_async_module"

    async def _serve(input, requirements=None, model_options=None):
        """Async serve implementation; yields to the loop briefly to mimic I/O."""
        await asyncio.sleep(0.01)
        return ModelOutputThunk(f"Async response to: {input[-1].content}")

    # AsyncMock keeps serve detectable as a coroutine function while
    # still recording calls for later assertions.
    module.serve = AsyncMock(side_effect=_serve)
    return module


@pytest.fixture
def mock_slow_sync_module():
    """Create a mock module with a slow synchronous serve function.

    The 1-second sleep gives a clear timing signal for event-loop
    blocking tests. Unlike the original, serve is wrapped in Mock
    (consistent with mock_sync_module / mock_async_module) so tests can
    also assert on call arguments; a Mock is not a coroutine function,
    so the endpoint still takes the sync (thread-pool) path.
    """
    import time

    module = Mock()
    module.__name__ = "test_slow_sync_module"

    def slow_sync_serve(input, requirements=None, model_options=None):
        """Slow synchronous serve function that would block the event loop."""
        time.sleep(1)  # Simulate blocking work with a clear timing signal
        return ModelOutputThunk(f"Slow sync response to: {input[-1].content}")

    module.serve = Mock(side_effect=slow_sync_serve)
    return module


class TestSyncAsyncServeHandling:
    """Test that serve handles both sync and async serve functions correctly."""

    @pytest.mark.asyncio
    async def test_sync_serve_function(self, mock_sync_module):
        """Test that synchronous serve functions work correctly."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Hello sync!")],
        )

        response = await endpoint(request)

        # The sync fixture echoes the last message's content back.
        assert response.choices[0].message.content == "Sync response to: Hello sync!"
        assert response.model == "test-model"
        assert response.object == "chat.completion"

    @pytest.mark.asyncio
    async def test_async_serve_function(self, mock_async_module):
        """Test that asynchronous serve functions work correctly."""
        endpoint = make_chat_endpoint(mock_async_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Hello async!")],
        )

        response = await endpoint(request)

        # The async fixture echoes the last message's content back.
        assert response.choices[0].message.content == "Async response to: Hello async!"
        assert response.model == "test-model"
        assert response.object == "chat.completion"

    @pytest.mark.asyncio
    async def test_slow_sync_does_not_block(self, mock_slow_sync_module):
        """Test that slow sync functions run in thread pool and don't block event loop.

        This test verifies non-blocking behavior by measuring timing. If the sync
        function blocked the event loop, two sequential calls would take 2x the time.
        With proper threading, they should overlap and take only slightly more than 1x.
        """
        import time

        endpoint = make_chat_endpoint(mock_slow_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Hello slow!")],
        )

        # Time two concurrent requests
        start = time.time()
        results = await asyncio.gather(endpoint(request), endpoint(request))
        elapsed = time.time() - start

        # If blocking: would take ~2s (1s + 1s sequentially)
        # If non-blocking: should take ~1s (both run concurrently in threads)
        # Allow some overhead, but should still be well below the blocking case.
        # NOTE(review): the blocking case is ~2.0s+, so the <2 threshold leaves
        # little headroom on a heavily loaded CI box — confirm stability there.
        assert elapsed < 2, (
            f"Took {elapsed:.3f}s - appears to be blocking (expected ~1s)"
        )
        assert all(
            r.choices[0].message.content == "Slow sync response to: Hello slow!"
            for r in results
        )

    @pytest.mark.asyncio
    async def test_concurrent_requests_with_sync_serve(self, mock_slow_sync_module):
        """Test that multiple sync requests can be handled concurrently."""
        endpoint = make_chat_endpoint(mock_slow_sync_module)

        # Distinct message content per request lets each response be
        # matched back to the request that produced it.
        requests = [
            ChatCompletionRequest(
                model="test-model",
                messages=[ChatMessage(role="user", content=f"Request {i}")],
            )
            for i in range(3)
        ]

        # Run requests concurrently
        responses = await asyncio.gather(*[endpoint(req) for req in requests])

        # All should complete successfully, in request order (gather
        # preserves ordering regardless of completion order).
        assert len(responses) == 3
        for i, response in enumerate(responses):
            assert (
                response.choices[0].message.content
                == f"Slow sync response to: Request {i}"
            )

    @pytest.mark.asyncio
    async def test_requirements_passed_to_serve(self, mock_sync_module):
        """Test that requirements are correctly passed to serve function."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Test")],
            requirements=["req1", "req2"],
        )

        await endpoint(request)

        # Verify serve was called with requirements
        mock_sync_module.serve.assert_called_once()
        call_kwargs = mock_sync_module.serve.call_args.kwargs
        assert call_kwargs["requirements"] == ["req1", "req2"]

    @pytest.mark.asyncio
    async def test_model_options_passed_to_serve(self, mock_sync_module):
        """Test that model options are correctly passed to serve function."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Test")],
            temperature=0.7,
            max_tokens=100,
        )

        await endpoint(request)

        # Verify serve was called with model_options keyed by the
        # translated ModelOption names, not the raw OpenAI names.
        mock_sync_module.serve.assert_called_once()
        call_kwargs = mock_sync_module.serve.call_args.kwargs
        model_options = call_kwargs["model_options"]
        assert ModelOption.TEMPERATURE in model_options
        assert ModelOption.MAX_NEW_TOKENS in model_options

    @pytest.mark.asyncio
    async def test_openai_params_mapped_to_model_options(self, mock_sync_module):
        """Test that OpenAI parameters are mapped to ModelOption sentinels."""
        endpoint = make_chat_endpoint(mock_sync_module)

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Test")],
            temperature=0.8,
            max_tokens=150,
            seed=42,
        )

        await endpoint(request)

        # Verify parameters are mapped correctly
        mock_sync_module.serve.assert_called_once()
        call_kwargs = mock_sync_module.serve.call_args.kwargs
        model_options = call_kwargs["model_options"]

        # Values must survive the rename unchanged.
        assert model_options[ModelOption.TEMPERATURE] == 0.8
        assert model_options[ModelOption.MAX_NEW_TOKENS] == 150
        assert model_options[ModelOption.SEED] == 42


class TestEndpointIntegration:
    """Integration tests for the full endpoint."""

    def test_endpoint_name_set_correctly(self, mock_sync_module):
        """The generated endpoint's __name__ embeds the module name."""
        endpoint = make_chat_endpoint(mock_sync_module)
        assert endpoint.__name__ == "chat_test_sync_module_endpoint"

    @pytest.mark.asyncio
    async def test_completion_id_generated(self, mock_sync_module):
        """Every response carries a distinct chatcmpl-prefixed id."""
        endpoint = make_chat_endpoint(mock_sync_module)

        req = ChatCompletionRequest(
            model="test-model", messages=[ChatMessage(role="user", content="Test")]
        )

        first = await endpoint(req)
        second = await endpoint(req)

        for resp in (first, second):
            assert resp.id.startswith("chatcmpl-")
        # IDs are unique per response, even for identical requests.
        assert first.id != second.id

    @pytest.mark.asyncio
    async def test_timestamp_generated(self, mock_sync_module):
        """Responses carry a positive integer creation timestamp."""
        endpoint = make_chat_endpoint(mock_sync_module)

        req = ChatCompletionRequest(
            model="test-model", messages=[ChatMessage(role="user", content="Test")]
        )

        resp = await endpoint(req)

        assert isinstance(resp.created, int)
        assert resp.created > 0
Loading