Skip to content
This repository was archived by the owner on Feb 10, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions argo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,34 @@ def after_request(response):
}
}


def _extract_response_payload(response_obj):
"""Return the nested response payload if present."""
if isinstance(response_obj, dict) and "response" in response_obj:
return response_obj["response"]
return response_obj


def _extract_response_text(response_obj):
"""Extract textual content from an Argo response object."""
payload = _extract_response_payload(response_obj)

if isinstance(payload, dict):
content = payload.get("content")
if content is None:
return ""
if isinstance(content, (dict, list)):
try:
return json.dumps(content)
except TypeError:
return str(content)
return content

if payload is None:
return ""

return payload if isinstance(payload, str) else str(payload)

# Define which models use which environment
MODEL_ENV = {
# Models using production environment
Expand Down Expand Up @@ -395,7 +423,7 @@ def chat_completions():
return _proxy_argo_error_response(response, logger)

json_response = response.json()
text = json_response.get("response", "")
text = _extract_response_text(json_response)
log_data_verbose("Response text", text)

# Process tool calls in response if present
Expand Down Expand Up @@ -423,7 +451,7 @@ def chat_completions():
return _proxy_argo_error_response(response, logger)

json_response = response.json()
text = json_response.get("response", "")
text = _extract_response_text(json_response)
log_data_verbose("Response text", text)

# Process tool calls in response if present
Expand Down Expand Up @@ -608,10 +636,15 @@ def _static_chat_response_with_tools(text, model_base, json_response):
model_family = determine_model_family(model_base)

# Process response to extract tool calls
response_payload = _extract_response_payload(json_response)

tool_calls, clean_text = tool_interceptor.process(
json_response.get("response", text),
response_payload,
model_family
)

if not clean_text:
clean_text = text

# Determine finish reason
finish_reason = "tool_calls" if tool_calls else "stop"
Expand Down Expand Up @@ -652,10 +685,15 @@ def _fake_stream_response_with_tools(json_response, model, model_base):
model_family = determine_model_family(model_base)

# Process response to extract tool calls
response_payload = _extract_response_payload(json_response)

tool_calls, clean_text = tool_interceptor.process(
json_response,
response_payload,
model_family
)

if not clean_text:
clean_text = _extract_response_text(response_payload)

# Start with role chunk
begin_chunk = {
Expand Down Expand Up @@ -759,7 +797,7 @@ def _stream_chat_response_with_tools(model, req_obj, model_base):
return

json_response = response.json()
text = json_response.get("response", "")
text = _extract_response_text(json_response)

# Use fake streaming with tool processing
yield from _fake_stream_response_with_tools(json_response, model, model_base)
Expand Down Expand Up @@ -815,7 +853,7 @@ def completions():
return _proxy_argo_error_response(response, logger)

json_response = response.json()
text = json_response.get("response", "")
text = _extract_response_text(json_response)
log_data_verbose("Response text", text)

if is_streaming:
Expand Down
222 changes: 222 additions & 0 deletions tests/test_tool_calling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
import json
import os
import sys

import pytest
import requests
from openai import OpenAI

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from tool_calls import ToolCall
from tool_calls.output_handle import (
ToolInterceptor,
tool_calls_to_openai,
tool_calls_to_openai_stream,
)

# Configuration
BRIDGE_URL = "http://localhost:7285" # Default argo_bridge URL
API_KEY = "dummy" # argo_bridge doesn't require real API keys
Expand Down Expand Up @@ -210,6 +222,216 @@ def test_conversation_with_tools(openai_client, mocker):
assert "Sunny" in final_message.content


def test_tool_interceptor_openai_nested_response():
    """An OpenAI-style payload with null content yields one parsed tool call."""
    call_entry = {
        "id": "call_test",
        "type": "function",
        "function": {"name": "add", "arguments": '{"a":8,"b":5}'},
    }
    payload = {"content": None, "tool_calls": [call_entry]}

    tool_calls, clean_text = ToolInterceptor().process(payload, model_family="openai")

    assert clean_text == ""
    assert tool_calls is not None
    assert len(tool_calls) == 1
    assert tool_calls[0].name == "add"
    assert json.loads(tool_calls[0].arguments) == {"a": 8, "b": 5}


def test_tool_interceptor_google_object_response():
    """A single Gemini-style tool_calls object (not a list) is parsed."""
    payload = {
        "content": None,
        "tool_calls": {"id": None, "name": "add", "args": {"a": 8, "b": 5}},
    }

    interceptor = ToolInterceptor()
    tool_calls, clean_text = interceptor.process(payload, model_family="google")

    assert clean_text == ""
    assert tool_calls is not None
    assert len(tool_calls) == 1
    assert tool_calls[0].name == "add"
    assert json.loads(tool_calls[0].arguments) == {"a": 8, "b": 5}


def test_tool_interceptor_anthropic_with_text():
    """Anthropic replies can carry prose alongside a tool_use block."""
    tool_use = {
        "id": "toolu_demo",
        "name": "calculate",
        "type": "tool_use",
        "input": {"expression": "2+2"},
    }
    envelope = {
        "response": {
            "content": "I'll call the math tool now.",
            "tool_calls": [tool_use],
        }
    }

    tool_calls, clean_text = ToolInterceptor().process(envelope, model_family="anthropic")

    assert clean_text == "I'll call the math tool now."
    assert tool_calls is not None and len(tool_calls) == 1
    assert tool_calls[0].name == "calculate"
    assert json.loads(tool_calls[0].arguments) == {"expression": "2+2"}


def test_tool_interceptor_multiple_calls_mixed_content():
    """Several OpenAI-style function calls are all extracted, in order."""

    def _entry(call_id, name, arguments):
        # Shorthand for an OpenAI "function" tool-call entry.
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        }

    envelope = {
        "response": {
            "content": None,
            "tool_calls": [
                _entry("call_1", "get_weather", '{"location":"Paris"}'),
                _entry("call_2", "calculate", '{"expression":"3*7"}'),
            ],
        }
    }

    tool_calls, clean_text = ToolInterceptor().process(envelope, model_family="openai")

    assert clean_text == ""
    assert tool_calls is not None and len(tool_calls) == 2
    assert [tc.name for tc in tool_calls] == ["get_weather", "calculate"]


def test_tool_interceptor_handles_missing_tool_calls():
    """A payload with text but no tool_calls key yields the text and None."""
    envelope = {"response": {"content": "All done."}}

    tool_calls, clean_text = ToolInterceptor().process(envelope, model_family="openai")

    assert clean_text == "All done."
    assert tool_calls is None


def test_tool_interceptor_ignores_malformed_entry(caplog):
    """A malformed entry is skipped with a warning; valid entries still parse."""
    caplog.set_level("WARNING")

    valid_entry = {
        "id": "call_good",
        "type": "function",
        "function": {"name": "add", "arguments": '{"a":1}'},
    }
    malformed_entry = {
        "id": "call_bad",
        "type": "function",
        "function": "not-a-dict",
    }
    envelope = {
        "response": {"content": None, "tool_calls": [valid_entry, malformed_entry]}
    }

    tool_calls, clean_text = ToolInterceptor().process(envelope, model_family="openai")

    assert clean_text == ""
    assert tool_calls is not None and len(tool_calls) == 1
    assert tool_calls[0].name == "add"
    assert any("Failed" in message for message in caplog.messages)


@pytest.mark.parametrize(
    "api_format",
    ["chat_completion", "response"],
)
def test_tool_calls_to_openai_conversion(api_format):
    """ToolCall lists convert to both supported OpenAI wire formats."""
    source_calls = [
        ToolCall(id="call1", name="add", arguments='{"a":1}'),
        ToolCall(id="call2", name="subtract", arguments='{"a":5}'),
    ]

    converted = tool_calls_to_openai(source_calls, api_format=api_format)

    assert len(converted) == 2
    first = converted[0]
    if api_format == "chat_completion":
        assert first.function.name == "add"
    else:
        assert first.name == "add"


def test_tool_calls_to_openai_stream_conversion():
    """Streaming conversion preserves the chunk index and function name."""
    streamed = tool_calls_to_openai_stream(
        ToolCall(id="call1", name="add", arguments='{"a":1}'),
        tc_index=3,
    )

    assert streamed.index == 3
    assert streamed.function.name == "add"


def test_tool_calls_to_openai_stream_invalid_type():
    """A non-ToolCall input is rejected with ValueError."""
    with pytest.raises(ValueError):
        tool_calls_to_openai_stream(123, tc_index=0)


def test_tool_interceptor_google_non_dict_tool_calls():
    """Non-dict entries in a Gemini tool_calls list are skipped."""
    envelope = {
        "response": {
            "content": None,
            "tool_calls": [
                {"name": "lookup", "args": {"value": "x"}},
                "unexpected",
            ],
        }
    }

    tool_calls, clean_text = ToolInterceptor().process(envelope, model_family="google")

    assert clean_text == ""
    assert tool_calls is not None and len(tool_calls) == 1
    assert tool_calls[0].name == "lookup"


def test_tool_interceptor_prompt_based_parsing():
    """<tool_call> tags embedded in plain text are parsed out of the prose."""
    raw_text = 'Pre text<tool_call>{"name":"add","arguments":{"a":1}}</tool_call>post text'

    tool_calls, clean_text = ToolInterceptor().process(raw_text, model_family="openai")

    assert clean_text.strip() == "Pre textpost text"
    assert tool_calls is not None and len(tool_calls) == 1
    assert tool_calls[0].name == "add"



def test_streaming_with_text_and_tool_call(openai_client, mocker):
"""Test streaming response with both text and a tool call."""
# Mock the streaming response
Expand Down
34 changes: 30 additions & 4 deletions tool_calls/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@
ToolChoiceToolParam,
ToolParam,
ToolUseBlock,
GeminiFunctionCall,
)
from .utils import API_FORMATS
from .utils import API_FORMATS, generate_id


class ToolCall(BaseModel):
Expand Down Expand Up @@ -120,8 +121,17 @@ def from_entry(
arguments=arguments_str,
)
elif api_format == "google":
# TODO: Implement Google API format
raise NotImplementedError("Google API format is not supported yet.")
origin_tool_call = GeminiFunctionCall.model_validate(tool_call)
arguments = (
json.dumps(origin_tool_call.args)
if not isinstance(origin_tool_call.args, str)
else origin_tool_call.args
)
return cls(
id=origin_tool_call.id or generate_id(),
name=origin_tool_call.name,
arguments=arguments,
)
else:
raise ValueError(f"Unsupported API format: {api_format}")

Expand Down Expand Up @@ -168,7 +178,23 @@ def to_tool_call(
)

elif api_format == "google":
raise NotImplementedError("Google API format is not supported yet.")
try:
parsed_args = (
json.loads(self.arguments)
if isinstance(self.arguments, str)
else self.arguments
)
except json.JSONDecodeError:
parsed_args = self.arguments

if not isinstance(parsed_args, dict):
parsed_args = {"value": parsed_args}

tool_call = GeminiFunctionCall(
id=self.id,
name=self.name,
args=parsed_args,
)

elif api_format == "general":
return self
Expand Down
Loading