diff --git a/argo_bridge.py b/argo_bridge.py
index 5184b54..bd4cb65 100644
--- a/argo_bridge.py
+++ b/argo_bridge.py
@@ -122,6 +122,34 @@ def after_request(response):
         }
     }
 
+
+def _extract_response_payload(response_obj):
+    """Return the nested response payload if present."""
+    if isinstance(response_obj, dict) and "response" in response_obj:
+        return response_obj["response"]
+    return response_obj
+
+
+def _extract_response_text(response_obj):
+    """Extract textual content from an Argo response object."""
+    payload = _extract_response_payload(response_obj)
+
+    if isinstance(payload, dict):
+        content = payload.get("content")
+        if content is None:
+            return ""
+        if isinstance(content, (dict, list)):
+            try:
+                return json.dumps(content)
+            except TypeError:
+                return str(content)
+        return content
+
+    if payload is None:
+        return ""
+
+    return payload if isinstance(payload, str) else str(payload)
+
 # Define which models use which environment
 MODEL_ENV = {
     # Models using production environment
@@ -395,7 +423,7 @@ def chat_completions():
             return _proxy_argo_error_response(response, logger)
 
         json_response = response.json()
-        text = json_response.get("response", "")
+        text = _extract_response_text(json_response)
         log_data_verbose("Response text", text)
 
         # Process tool calls in response if present
@@ -423,7 +451,7 @@ def chat_completions():
             return _proxy_argo_error_response(response, logger)
 
         json_response = response.json()
-        text = json_response.get("response", "")
+        text = _extract_response_text(json_response)
         log_data_verbose("Response text", text)
 
         # Process tool calls in response if present
@@ -608,10 +636,15 @@ def _static_chat_response_with_tools(text, model_base, json_response):
     model_family = determine_model_family(model_base)
 
     # Process response to extract tool calls
+    response_payload = _extract_response_payload(json_response)
+
     tool_calls, clean_text = tool_interceptor.process(
-        json_response.get("response", text),
+        response_payload,
         model_family
     )
+
+    if not clean_text:
+        clean_text = text
 
     # Determine finish reason
     finish_reason = "tool_calls" if tool_calls else "stop"
@@ -652,10 +685,15 @@ def _fake_stream_response_with_tools(json_response, model, model_base):
     model_family = determine_model_family(model_base)
 
     # Process response to extract tool calls
+    response_payload = _extract_response_payload(json_response)
+
     tool_calls, clean_text = tool_interceptor.process(
-        json_response,
+        response_payload,
         model_family
     )
+
+    if not clean_text:
+        clean_text = _extract_response_text(response_payload)
 
     # Start with role chunk
     begin_chunk = {
@@ -759,7 +797,7 @@ def _stream_chat_response_with_tools(model, req_obj, model_base):
             return
 
         json_response = response.json()
-        text = json_response.get("response", "")
+        text = _extract_response_text(json_response)
 
         # Use fake streaming with tool processing
         yield from _fake_stream_response_with_tools(json_response, model, model_base)
@@ -815,7 +853,7 @@ def completions():
         return _proxy_argo_error_response(response, logger)
 
     json_response = response.json()
-    text = json_response.get("response", "")
+    text = _extract_response_text(json_response)
     log_data_verbose("Response text", text)
 
     if is_streaming:
diff --git a/tests/test_tool_calling.py b/tests/test_tool_calling.py
index 25922cc..5a61f11 100644
--- a/tests/test_tool_calling.py
+++ b/tests/test_tool_calling.py
@@ -1,8 +1,20 @@
 import json
+import os
+import sys
+
 import pytest
 import requests
 from openai import OpenAI
 
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from tool_calls import ToolCall
+from tool_calls.output_handle import (
+    ToolInterceptor,
+    tool_calls_to_openai,
+    tool_calls_to_openai_stream,
+)
+
 # Configuration
 BRIDGE_URL = "http://localhost:7285"  # Default argo_bridge URL
 API_KEY = "dummy"  # argo_bridge doesn't require real API keys
@@ -210,6 +222,216 @@ def test_conversation_with_tools(openai_client, mocker):
     assert "Sunny" in final_message.content
 
 
+def test_tool_interceptor_openai_direct_payload():
+    interceptor = ToolInterceptor()
+    response_payload = {
+        "content": None,
+        "tool_calls": [
+            {
+                "id": "call_test",
+                "type": "function",
+                "function": {
+                    "name": "add",
+                    "arguments": "{\"a\":8,\"b\":5}"
+                }
+            }
+        ]
+    }
+
+    tool_calls, clean_text = interceptor.process(response_payload, model_family="openai")
+
+    assert clean_text == ""
+    assert tool_calls is not None
+    assert len(tool_calls) == 1
+    assert tool_calls[0].name == "add"
+    assert json.loads(tool_calls[0].arguments) == {"a": 8, "b": 5}
+
+
+def test_tool_interceptor_google_object_response():
+    interceptor = ToolInterceptor()
+    response_payload = {
+        "content": None,
+        "tool_calls": {
+            "id": None,
+            "name": "add",
+            "args": {"a": 8, "b": 5}
+        }
+    }
+
+    tool_calls, clean_text = interceptor.process(response_payload, model_family="google")
+
+    assert clean_text == ""
+    assert tool_calls is not None
+    assert len(tool_calls) == 1
+    assert tool_calls[0].name == "add"
+    assert json.loads(tool_calls[0].arguments) == {"a": 8, "b": 5}
+
+
+def test_tool_interceptor_anthropic_with_text():
+    interceptor = ToolInterceptor()
+    response_payload = {
+        "response": {
+            "content": "I'll call the math tool now.",
+            "tool_calls": [
+                {
+                    "id": "toolu_demo",
+                    "name": "calculate",
+                    "type": "tool_use",
+                    "input": {"expression": "2+2"}
+                }
+            ]
+        }
+    }
+
+    tool_calls, clean_text = interceptor.process(response_payload, model_family="anthropic")
+
+    assert clean_text == "I'll call the math tool now."
+    assert tool_calls is not None and len(tool_calls) == 1
+    assert tool_calls[0].name == "calculate"
+    assert json.loads(tool_calls[0].arguments) == {"expression": "2+2"}
+
+
+def test_tool_interceptor_multiple_calls_mixed_content():
+    interceptor = ToolInterceptor()
+    response_payload = {
+        "response": {
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "call_1",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"location":"Paris"}'
+                    }
+                },
+                {
+                    "id": "call_2",
+                    "type": "function",
+                    "function": {
+                        "name": "calculate",
+                        "arguments": '{"expression":"3*7"}'
+                    }
+                }
+            ]
+        }
+    }
+
+    tool_calls, clean_text = interceptor.process(response_payload, model_family="openai")
+
+    assert clean_text == ""
+    assert tool_calls is not None and len(tool_calls) == 2
+    assert [tc.name for tc in tool_calls] == ["get_weather", "calculate"]
+
+
+def test_tool_interceptor_handles_missing_tool_calls():
+    interceptor = ToolInterceptor()
+    response_payload = {"response": {"content": "All done."}}
+
+    tool_calls, clean_text = interceptor.process(response_payload, model_family="openai")
+
+    assert clean_text == "All done."
+    assert tool_calls is None
+
+
+def test_tool_interceptor_ignores_malformed_entry(caplog):
+    interceptor = ToolInterceptor()
+    caplog.set_level("WARNING")
+
+    response_payload = {
+        "response": {
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "call_good",
+                    "type": "function",
+                    "function": {"name": "add", "arguments": "{\"a\":1}"}
+                },
+                {
+                    "id": "call_bad",
+                    "type": "function",
+                    "function": "not-a-dict"
+                }
+            ]
+        }
+    }
+
+    tool_calls, clean_text = interceptor.process(response_payload, model_family="openai")
+
+    assert clean_text == ""
+    assert tool_calls is not None and len(tool_calls) == 1
+    assert tool_calls[0].name == "add"
+    assert any("Failed" in message for message in caplog.messages)
+
+
+@pytest.mark.parametrize(
+    "api_format",
+    ["chat_completion", "response"],
+)
+def test_tool_calls_to_openai_conversion(api_format):
+    calls = [
+        ToolCall(id="call1", name="add", arguments="{\"a\":1}"),
+        ToolCall(id="call2", name="subtract", arguments="{\"a\":5}"),
+    ]
+
+    converted = tool_calls_to_openai(calls, api_format=api_format)
+
+    assert len(converted) == 2
+    if api_format == "chat_completion":
+        assert converted[0].function.name == "add"
+    else:
+        assert converted[0].name == "add"
+
+
+def test_tool_calls_to_openai_stream_conversion():
+    call = ToolCall(id="call1", name="add", arguments="{\"a\":1}")
+
+    result = tool_calls_to_openai_stream(call, tc_index=3)
+
+    assert result.index == 3
+    assert result.function.name == "add"
+
+
+def test_tool_calls_to_openai_stream_invalid_type():
+    with pytest.raises(ValueError):
+        tool_calls_to_openai_stream(123, tc_index=0)
+
+
+def test_tool_interceptor_google_non_dict_tool_calls():
+    interceptor = ToolInterceptor()
+    response_payload = {
+        "response": {
+            "content": None,
+            "tool_calls": [
+                {
+                    "name": "lookup",
+                    "args": {"value": "x"},
+                },
+                "unexpected"
+            ]
+        }
+    }
+
+    tool_calls, clean_text = interceptor.process(response_payload, model_family="google")
+
+    assert clean_text == ""
+    assert tool_calls is not None and len(tool_calls) == 1
+    assert tool_calls[0].name == "lookup"
+
+
+def test_tool_interceptor_prompt_based_parsing():
+    interceptor = ToolInterceptor()
+
+    text = "Pre text{\"name\":\"add\",\"arguments\":{\"a\":1}}post text"
+
+    tool_calls, clean_text = interceptor.process(text, model_family="openai")
+
+    assert clean_text.strip() == "Pre textpost text"
+    assert tool_calls is not None and len(tool_calls) == 1
+    assert tool_calls[0].name == "add"
+
+
 def test_streaming_with_text_and_tool_call(openai_client, mocker):
     """Test streaming response with both text and a tool call."""
     # Mock the streaming response
diff --git a/tool_calls/handler.py b/tool_calls/handler.py
index e37440e..dd699d2 100644
--- a/tool_calls/handler.py
+++ b/tool_calls/handler.py
@@ -48,8 +48,9 @@
     ToolChoiceToolParam,
     ToolParam,
     ToolUseBlock,
+    GeminiFunctionCall,
 )
-from .utils import API_FORMATS
+from .utils import API_FORMATS, generate_id
 
 
 class ToolCall(BaseModel):
@@ -120,8 +121,17 @@ def from_entry(
                 arguments=arguments_str,
             )
         elif api_format == "google":
-            # TODO: Implement Google API format
-            raise NotImplementedError("Google API format is not supported yet.")
+            origin_tool_call = GeminiFunctionCall.model_validate(tool_call)
+            arguments = (
+                json.dumps(origin_tool_call.args)
+                if not isinstance(origin_tool_call.args, str)
+                else origin_tool_call.args
+            )
+            return cls(
+                id=origin_tool_call.id or generate_id(),
+                name=origin_tool_call.name,
+                arguments=arguments,
+            )
         else:
             raise ValueError(f"Unsupported API format: {api_format}")
 
@@ -168,7 +178,23 @@ def to_tool_call(
             )
         elif api_format == "google":
-            raise NotImplementedError("Google API format is not supported yet.")
+            try:
+                parsed_args = (
+                    json.loads(self.arguments)
+                    if isinstance(self.arguments, str)
+                    else self.arguments
+                )
+            except json.JSONDecodeError:
+                parsed_args = self.arguments
+
+            if not isinstance(parsed_args, dict):
+                parsed_args = {"value": parsed_args}
+
+            return GeminiFunctionCall(
+                id=self.id,
+                name=self.name,
+                args=parsed_args,
+            )
         elif api_format == "general":
             return self
 
diff --git a/tool_calls/output_handle.py b/tool_calls/output_handle.py
index b9bead0..cb21945 100644
--- a/tool_calls/output_handle.py
+++ b/tool_calls/output_handle.py
@@ -164,20 +164,49 @@ def _process_openai_native(
     Returns:
         Tuple of (list of ToolCall objects or None, text content)
     """
-    content = response_data.get("content", "")
-    tool_calls_data = response_data.get("tool_calls", [])
+    response_payload = response_data.get("response", response_data)
+
+    if not isinstance(response_payload, dict):
+        logger.debug("OpenAI response payload is not a dict; returning string content")
+        content = response_payload if isinstance(response_payload, str) else str(response_payload)
+        return None, content
+
+    content = response_payload.get("content")
+    if content is None:
+        content = ""
+
+    tool_calls_data = response_payload.get("tool_calls", [])
+
+    if isinstance(tool_calls_data, dict):
+        tool_calls_data = [tool_calls_data]
+    elif tool_calls_data is None:
+        tool_calls_data = []
+    elif not isinstance(tool_calls_data, list):
+        logger.debug(
+            "Unexpected tool_calls payload type for OpenAI: %s",
+            type(tool_calls_data),
+        )
+        tool_calls_data = []
 
     # Convert tool calls to ToolCall objects
     tool_calls = None
     if tool_calls_data:
         tool_calls = []
         for tool_call_dict in tool_calls_data:
-            # Use ToolCall.from_entry to convert from OpenAI format
-            tool_call = ToolCall.from_entry(
-                tool_call_dict, api_format="openai-chatcompletion"
-            )
+            try:
+                tool_call = ToolCall.from_entry(
+                    tool_call_dict, api_format="openai-chatcompletion"
+                )
+            except Exception as exc:  # noqa: BLE001 - allow broad catch for logging
+                logger.warning(
+                    "Failed to parse OpenAI tool call entry: %s", exc, exc_info=True
+                )
+                continue
             tool_calls.append(tool_call)
+        if not tool_calls:
+            tool_calls = None
+
     return tool_calls, content
 
 
 def _process_anthropic_native(
@@ -247,9 +276,42 @@ def _process_google_native(
     Returns:
         Tuple of (list of ToolCall objects or None, text content)
     """
-    # Placeholder implementation - to be implemented later
-    logger.warning("Google native tool calling not implemented yet, falling back to OpenAI format")
-    raise NotImplementedError("Google native tool calling is not yet implemented. Please implement Google-specific tool calling format processing.")
+    response_payload = response_data.get("response", response_data)
+
+    if not isinstance(response_payload, dict):
+        logger.debug("Google response payload is not a dict; returning string content")
+        content = response_payload if isinstance(response_payload, str) else str(response_payload)
+        return None, content
+
+    text_content = response_payload.get("content")
+    if text_content is None:
+        text_content = ""
+
+    gemini_tool_calls = response_payload.get("tool_calls")
+
+    tool_calls = None
+    if gemini_tool_calls:
+        if isinstance(gemini_tool_calls, dict):
+            gemini_tool_calls = [gemini_tool_calls]
+        elif not isinstance(gemini_tool_calls, list):
+            logger.debug(
+                "Unexpected tool_calls payload type for Google: %s",
+                type(gemini_tool_calls),
+            )
+            gemini_tool_calls = [gemini_tool_calls]
+
+        tool_calls = []
+        for gemini_tool_call in gemini_tool_calls:
+            try:
+                tool_call = ToolCall.from_entry(gemini_tool_call, api_format="google")
+                tool_calls.append(tool_call)
+            except Exception as exc:
+                logger.warning("Failed to parse Gemini tool call: %s", exc)
+
+    if not tool_calls:
+        tool_calls = None
+
+    return tool_calls, text_content
 
 
 def chat_completion_to_response_tool_call(
diff --git a/tool_calls/types.py b/tool_calls/types.py
index 349be52..14e68cf 100644
--- a/tool_calls/types.py
+++ b/tool_calls/types.py
@@ -12,9 +12,9 @@
 - Google Gemini Types (TODO)
 """
 
-from typing import Dict, List, Literal, Optional, TypeAlias, Union
+from typing import Any, Dict, List, Literal, Optional, TypeAlias, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, Field
 
 
 # ======================================================================
 # 1. OPENAI TYPES (CHAT COMPLETION & RESPONSES API)
@@ -287,6 +287,17 @@ class ToolUseBlock(BaseModel):
 
 
 # ======================================================================
-# 3. GOOGLE GEMINI TYPES (TODO)
+# 3. GOOGLE GEMINI TYPES
 # ======================================================================
-# Add Google Gemini-compatible function call types here...
+
+
+class GeminiFunctionCall(BaseModel):
+    """Minimal representation of a Gemini function call."""
+
+    id: Optional[str] = None
+    name: str
+    args: Dict[str, Any] = Field(default_factory=dict)
+    type: Optional[str] = None
+
+    model_config = ConfigDict(extra="allow")
+