5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

- Support `extra_body` for the Kimi, OpenAILegacy, OpenAIResponses, and Anthropic chat providers
- Kimi thinking config is now passed via `extra_body`

## [0.34.0] - 2025-12-19

- Support Vertex AI in GoogleGenAI chat provider
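In practice, the new field is passed through `with_generation_kwargs` and merged verbatim into the outgoing request body. A minimal usage sketch (the model and key values are placeholders; the API shape comes from the diffs and tests below):

```python
from kosong.chat_provider.kimi import Kimi

# `extra_body` entries ride alongside the standard parameters and are
# merged into the provider's HTTP request body as-is.
provider = Kimi(model="kimi-k2-turbo-preview", api_key="...").with_generation_kwargs(
    extra_body={"user_metadata": {"foo": "bar"}}
)
```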
4 changes: 4 additions & 0 deletions src/kosong/chat_provider/__init__.py
@@ -4,6 +4,7 @@

from kosong.message import ContentPart, Message, ToolCall, ToolCallPart
from kosong.tooling import Tool
from kosong.utils.typing import JsonType


@runtime_checkable
@@ -96,6 +97,9 @@ def input(self) -> int:
type ThinkingEffort = Literal["off", "low", "medium", "high"]
"""The effort level for thinking."""

type ExtraBody = dict[str, JsonType]
"""Arbitrary provider-specific request fields to be merged into the HTTP request body."""


class ChatProviderError(Exception):
"""The error raised by a chat provider."""
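To make the new alias concrete: any JSON-compatible mapping qualifies, and its keys are passed through untouched. A small illustrative sketch (the field names here are examples, not part of the API):

```python
from kosong.chat_provider import ExtraBody

# Illustrative values only; anything expressible as JSON is accepted.
body: ExtraBody = {
    "metadata": {"trace_id": "t-123"},  # e.g. Anthropic's request metadata
    "user_metadata": {"foo": "bar"},    # e.g. a Kimi-specific field
}
```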
39 changes: 28 additions & 11 deletions src/kosong/chat_provider/kimi.py
@@ -2,7 +2,7 @@
import os
import uuid
from collections.abc import AsyncIterator, Sequence
-from typing import TYPE_CHECKING, Any, Self, TypedDict, Unpack, cast
from typing import TYPE_CHECKING, Any, Literal, Self, TypedDict, Unpack, cast

import httpx
from openai import AsyncOpenAI, AsyncStream, OpenAIError, omit
@@ -18,6 +18,7 @@
from kosong.chat_provider import (
ChatProvider,
ChatProviderError,
ExtraBody,
StreamedMessage,
StreamedMessagePart,
ThinkingEffort,
@@ -33,6 +34,9 @@ def type_check(kimi: "Kimi"):
_: ChatProvider = kimi


type KimiThinkingType = Literal["enabled", "disabled"]


class Kimi(ChatProvider):
"""
A chat provider that uses the Kimi API.
@@ -53,6 +57,10 @@ class GenerationKwargs(TypedDict, total=False):
class GenerationKwargs(TypedDict, total=False):
"""
See https://platform.moonshot.ai/docs/api/chat#request-body.

Notes:
`thinking` is a Kimi-specific request field. It is injected into the request body via
the OpenAI SDK's `extra_body`, because it is not part of the upstream OpenAI schema.
"""

max_tokens: int | None
@@ -63,7 +71,8 @@ class GenerationKwargs(TypedDict, total=False):
frequency_penalty: float | None
stop: str | list[str] | None
prompt_cache_key: str | None
-reasoning_effort: str | None
thinking: KimiThinkingType | None
extra_body: ExtraBody | None

def __init__(
self,
@@ -115,20 +124,32 @@ async def generate(
"max_tokens": 32000,
}
generation_kwargs.update(self._generation_kwargs)

thinking: KimiThinkingType | None = generation_kwargs.pop("thinking", None)
extra_body: ExtraBody | None = generation_kwargs.pop("extra_body", None)

if "temperature" not in generation_kwargs:
thinking_enabled: bool = thinking == "enabled" or (
thinking is None and "kimi-k2-thinking" in self.model
)
# set default temperature based on model name
if "kimi-k2-thinking" in self.model or self._generation_kwargs.get("reasoning_effort"):
if thinking_enabled:
generation_kwargs["temperature"] = 1.0
elif "kimi-k2-" in self.model:
generation_kwargs["temperature"] = 0.6

try:
extra_body = dict(extra_body) if extra_body is not None else {}
if thinking is not None:
extra_body["thinking"] = {"type": thinking}

response = await self.client.chat.completions.create(
model=self.model,
messages=messages,
tools=(_convert_tool(tool) for tool in tools),
stream=self.stream,
stream_options={"include_usage": True} if self.stream else omit,
extra_body=extra_body,
**generation_kwargs,
)
return KimiStreamedMessage(response)
@@ -138,14 +159,10 @@ async def generate(
def with_thinking(self, effort: ThinkingEffort) -> Self:
match effort:
case "off":
-reasoning_effort = None
-case "low":
-reasoning_effort = "low"
-case "medium":
-reasoning_effort = "medium"
-case "high":
-reasoning_effort = "high"
-return self.with_generation_kwargs(reasoning_effort=reasoning_effort)
thinking: KimiThinkingType = "disabled"
case "low" | "medium" | "high":
thinking = "enabled"
return self.with_generation_kwargs(thinking=thinking)

def with_generation_kwargs(self, **kwargs: Unpack[GenerationKwargs]) -> Self:
"""
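The merge order above matters: the caller's `extra_body` is copied first, then the Kimi-specific `thinking` field is layered on top, so an explicit `thinking` kwarg always wins over a `thinking` key supplied in `extra_body`. A standalone sketch of that logic (mirroring the code above; `merge_extra_body` is illustrative, not a public helper):

```python
from typing import Literal

type KimiThinkingType = Literal["enabled", "disabled"]

def merge_extra_body(
    thinking: KimiThinkingType | None,
    extra_body: dict | None,
) -> dict:
    # Copy so the provider's stored generation kwargs are never mutated,
    # then inject the Kimi-specific thinking field on top.
    merged = dict(extra_body) if extra_body is not None else {}
    if thinking is not None:
        merged["thinking"] = {"type": thinking}
    return merged

# merge_extra_body("enabled", {"user_metadata": {"foo": "bar"}})
# -> {"user_metadata": {"foo": "bar"}, "thinking": {"type": "enabled"}}
```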
4 changes: 4 additions & 0 deletions src/kosong/contrib/chat_provider/anthropic.py
@@ -69,6 +69,7 @@
APITimeoutError,
ChatProvider,
ChatProviderError,
ExtraBody,
StreamedMessagePart,
ThinkingEffort,
TokenUsage,
@@ -115,6 +116,7 @@ class GenerationKwargs(TypedDict, total=False):

beta_features: list[BetaFeatures] | None
extra_headers: Mapping[str, str] | None
extra_body: ExtraBody | None

def __init__(
self,
@@ -194,6 +196,7 @@ async def generate(
**{"anthropic-beta": ",".join(str(e) for e in betas)},
**(generation_kwargs.pop("extra_headers", {})),
}
extra_body: ExtraBody | None = generation_kwargs.pop("extra_body", None)

tools_ = [_convert_tool(tool) for tool in tools]
if tools:
@@ -206,6 +209,7 @@
tools=tools_,
stream=self._stream,
extra_headers=extra_headers,
extra_body=extra_body,
**generation_kwargs,
)
return AnthropicStreamedMessage(response)
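Note that `extra_body` complements the existing `extra_headers` support: headers carry transport-level flags such as `anthropic-beta`, while `extra_body` merges fields into the JSON payload itself. A hedged usage sketch (values are placeholders mirroring the snapshot test below; the import path is assumed from the file location):

```python
from kosong.contrib.chat_provider.anthropic import Anthropic

provider = Anthropic(
    model="claude-sonnet-4-20250514",
    api_key="...",            # placeholder
    default_max_tokens=1024,
).with_generation_kwargs(
    extra_body={"metadata": {"trace_id": "t-123"}}  # merged into the request JSON
)
```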
11 changes: 10 additions & 1 deletion src/kosong/contrib/chat_provider/openai_legacy.py
@@ -14,7 +14,13 @@
)
from typing_extensions import TypedDict

-from kosong.chat_provider import ChatProvider, StreamedMessagePart, ThinkingEffort, TokenUsage
from kosong.chat_provider import (
ChatProvider,
ExtraBody,
StreamedMessagePart,
ThinkingEffort,
TokenUsage,
)
from kosong.chat_provider.openai_common import (
convert_error,
thinking_effort_to_reasoning_effort,
@@ -57,6 +63,7 @@ class GenerationKwargs(TypedDict, extra_items=Any, total=False):
frequency_penalty: float | None
stop: str | list[str] | None
prompt_cache_key: str | None
extra_body: ExtraBody | None

def __init__(
self,
@@ -106,6 +113,7 @@ async def generate(

generation_kwargs: dict[str, Any] = {}
generation_kwargs.update(self._generation_kwargs)
extra_body: ExtraBody | None = generation_kwargs.pop("extra_body", None)

try:
response = await self.client.chat.completions.create(
Expand All @@ -115,6 +123,7 @@ async def generate(
stream=self.stream,
stream_options={"include_usage": True} if self.stream else omit,
reasoning_effort=self._reasoning_effort,
extra_body=extra_body,
**generation_kwargs,
)
return OpenAILegacyStreamedMessage(response, self._reasoning_key)
11 changes: 10 additions & 1 deletion src/kosong/contrib/chat_provider/openai_responses.py
@@ -30,7 +30,13 @@
from openai.types.shared.reasoning_effort import ReasoningEffort
from openai.types.shared_params.responses_model import ResponsesModel

-from kosong.chat_provider import ChatProvider, StreamedMessagePart, ThinkingEffort, TokenUsage
from kosong.chat_provider import (
ChatProvider,
ExtraBody,
StreamedMessagePart,
ThinkingEffort,
TokenUsage,
)
from kosong.chat_provider.openai_common import convert_error, thinking_effort_to_reasoning_effort
from kosong.contrib.chat_provider.common import ToolMessageConversion
from kosong.message import (
@@ -101,6 +107,7 @@ class GenerationKwargs(TypedDict, total=False):
top_logprobs: float | None
top_p: float | None
user: str | None
extra_body: ExtraBody | None

def __init__(
self,
@@ -145,6 +152,7 @@ async def generate(

generation_kwargs: dict[str, Any] = {}
generation_kwargs.update(self._generation_kwargs)
extra_body: ExtraBody | None = generation_kwargs.pop("extra_body", None)
generation_kwargs["reasoning"] = Reasoning(
effort=generation_kwargs.pop("reasoning_effort", None),
summary="auto",
@@ -158,6 +166,7 @@
input=inputs,
tools=[_convert_tool(tool) for tool in tools],
store=False,
extra_body=extra_body,
**generation_kwargs,
)
return OpenAIResponsesStreamedMessage(response)
17 changes: 17 additions & 0 deletions tests/api_snapshot_tests/test_anthropic.py
@@ -491,6 +491,23 @@ async def test_anthropic_generation_kwargs():
)


@pytest.mark.asyncio
async def test_anthropic_extra_body():
with respx.mock(base_url="https://api.anthropic.com") as mock:
mock.post("/v1/messages").mock(return_value=Response(200, json=make_anthropic_response()))
provider = Anthropic(
model="claude-sonnet-4-20250514",
api_key="test-key",
default_max_tokens=1024,
stream=False,
).with_generation_kwargs(extra_body={"metadata": {"trace_id": "t-123"}})
stream = await provider.generate("", [], [Message(role="user", content="Hi")])
async for _ in stream:
pass
body = json.loads(mock.calls.last.request.content.decode())
assert body["metadata"] == snapshot({"trace_id": "t-123"})


@pytest.mark.asyncio
async def test_anthropic_with_thinking():
with respx.mock(base_url="https://api.anthropic.com") as mock:
22 changes: 21 additions & 1 deletion tests/api_snapshot_tests/test_kimi.py
@@ -320,4 +320,24 @@ async def test_kimi_with_thinking():
async for _ in stream:
pass
body = json.loads(mock.calls.last.request.content.decode())
assert (body["reasoning_effort"], body["temperature"]) == snapshot(("high", 1.0))
assert (body["thinking"]["type"], body["temperature"]) == snapshot(("enabled", 1.0))


@pytest.mark.asyncio
async def test_kimi_with_thinking_merges_extra_body():
with respx.mock(base_url="https://api.moonshot.ai") as mock:
mock.post("/v1/chat/completions").mock(
return_value=Response(200, json=make_chat_completion_response())
)
provider = (
Kimi(model="kimi-k2-turbo-preview", api_key="test-key", stream=False)
.with_generation_kwargs(extra_body={"user_metadata": {"foo": "bar"}})
.with_thinking("high")
)
stream = await provider.generate("", [], [Message(role="user", content="Think")])
async for _ in stream:
pass
body = json.loads(mock.calls.last.request.content.decode())
assert (body["user_metadata"], body["thinking"]["type"]) == snapshot(
({"foo": "bar"}, "enabled")
)
16 changes: 16 additions & 0 deletions tests/api_snapshot_tests/test_openai_legacy.py
@@ -297,6 +297,22 @@ async def test_openai_legacy_generation_kwargs():
assert (body["temperature"], body["max_tokens"]) == snapshot((0.7, 2048))


@pytest.mark.asyncio
async def test_openai_legacy_extra_body():
with respx.mock(base_url="https://api.openai.com") as mock:
mock.post("/v1/chat/completions").mock(
return_value=Response(200, json=make_chat_completion_response())
)
provider = OpenAILegacy(
model="gpt-4.1", api_key="test-key", stream=False
).with_generation_kwargs(extra_body={"metadata": {"trace_id": "t-123"}})
stream = await provider.generate("", [], [Message(role="user", content="Hi")])
async for _ in stream:
pass
body = json.loads(mock.calls.last.request.content.decode())
assert body["metadata"] == snapshot({"trace_id": "t-123"})


@pytest.mark.asyncio
async def test_openai_legacy_with_thinking():
with respx.mock(base_url="https://api.openai.com") as mock:
14 changes: 14 additions & 0 deletions tests/api_snapshot_tests/test_openai_responses.py
@@ -382,6 +382,20 @@ async def test_openai_responses_generation_kwargs():
assert (body["temperature"], body["max_output_tokens"]) == snapshot((0.7, 2048))


@pytest.mark.asyncio
async def test_openai_responses_extra_body():
with respx.mock(base_url="https://api.openai.com") as mock:
mock.post("/v1/responses").mock(return_value=Response(200, json=make_response()))
provider = OpenAIResponses(
model="gpt-4.1", api_key="test-key", stream=False
).with_generation_kwargs(extra_body={"metadata": {"trace_id": "t-123"}})
stream = await provider.generate("", [], [Message(role="user", content="Hi")])
async for _ in stream:
pass
body = json.loads(mock.calls.last.request.content.decode())
assert body["metadata"] == snapshot({"trace_id": "t-123"})


@pytest.mark.asyncio
async def test_openai_responses_with_thinking():
with respx.mock(base_url="https://api.openai.com") as mock: