Skip to content

Commit ceaa5d2

Browse files
fix: Use explicit mapping for include_default_tools deserialization
Replace ToolDefinition.resolve_kind() with an explicit BUILT_IN_TOOL_CLASSES mapping for deserializing include_default_tools. This is more reliable as resolve_kind() depends on subclass discovery which may fail if classes haven't been imported. Also add comprehensive tests for include_default_tools serialization and deserialization to ensure round-trip works correctly. Co-authored-by: openhands <[email protected]>
1 parent 5547cac commit ceaa5d2

File tree

2 files changed

+173
-7
lines changed

2 files changed

+173
-7
lines changed

openhands-sdk/openhands/sdk/agent/base.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,26 @@
2121
from openhands.sdk.llm.utils.model_prompt_spec import get_model_prompt_spec
2222
from openhands.sdk.logger import get_logger
2323
from openhands.sdk.mcp import create_mcp_tools
24-
from openhands.sdk.tool import BUILT_IN_TOOLS, Tool, ToolDefinition, resolve_tool
24+
from openhands.sdk.tool import (
25+
BUILT_IN_TOOLS,
26+
FinishTool,
27+
ThinkTool,
28+
Tool,
29+
ToolDefinition,
30+
resolve_tool,
31+
)
2532
from openhands.sdk.utils.models import DiscriminatedUnionMixin
2633

2734

35+
# Explicit mapping of built-in tool class names to their classes.
36+
# This is more reliable than using resolve_kind() which depends on
37+
# subclass discovery and may fail if classes haven't been imported.
38+
BUILT_IN_TOOL_CLASSES: dict[str, type[ToolDefinition]] = {
39+
"FinishTool": FinishTool,
40+
"ThinkTool": ThinkTool,
41+
}
42+
43+
2844
if TYPE_CHECKING:
2945
from openhands.sdk.conversation import ConversationState, LocalConversation
3046
from openhands.sdk.conversation.types import (
@@ -176,7 +192,7 @@ class AgentBase(DiscriminatedUnionMixin, ABC):
176192
@field_validator("include_default_tools", mode="before")
177193
@classmethod
178194
def _val_include_default_tools(cls, v: Any) -> list[str]:
179-
"""Convert tool classes to their class names if needed.
195+
"""Convert tool classes to their class names and validate them.
180196
181197
Accepts both strings and tool classes for backward compatibility.
182198
"""
@@ -185,12 +201,19 @@ def _val_include_default_tools(cls, v: Any) -> list[str]:
185201
result: list[str] = []
186202
for item in v:
187203
if isinstance(item, str):
188-
result.append(item)
204+
name = item
189205
elif isinstance(item, type):
190-
# It's a tool class
191-
result.append(item.__name__)
206+
# It's a tool class - convert to name
207+
name = item.__name__
192208
else:
193209
raise ValueError(f"Invalid item type: {type(item)}")
210+
# Validate that the tool name is a known built-in tool
211+
if name not in BUILT_IN_TOOL_CLASSES:
212+
raise ValueError(
213+
f"Unknown built-in tool class: '{name}'. "
214+
f"Expected one of: {list(BUILT_IN_TOOL_CLASSES.keys())}"
215+
)
216+
result.append(name)
194217
return result
195218

196219
@property
@@ -293,9 +316,14 @@ def _initialize(self, state: "ConversationState"):
293316
)
294317

295318
# Include default tools from include_default_tools; not subject to regex
296-
# filtering. Resolve tool class names and instantiate using their .create()
319+
# filtering. Use explicit mapping to resolve tool class names.
297320
for tool_name in self.include_default_tools:
298-
tool_class = ToolDefinition.resolve_kind(tool_name)
321+
tool_class = BUILT_IN_TOOL_CLASSES.get(tool_name)
322+
if tool_class is None:
323+
raise ValueError(
324+
f"Unknown built-in tool class: '{tool_name}'. "
325+
f"Expected one of: {list(BUILT_IN_TOOL_CLASSES.keys())}"
326+
)
299327
tool_instances = tool_class.create(state)
300328
tools.extend(tool_instances)
301329

tests/sdk/agent/test_agent_serialization.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from openhands.sdk.llm import LLM
1414
from openhands.sdk.mcp.client import MCPClient
1515
from openhands.sdk.mcp.tool import MCPToolDefinition
16+
from openhands.sdk.tool import FinishTool
1617
from openhands.sdk.tool.tool import ToolDefinition
1718
from openhands.sdk.utils.models import OpenHandsModel
1819

@@ -221,3 +222,140 @@ class TestModel(BaseModel):
221222
assert isinstance(deserialized_model.agent, Agent)
222223
assert deserialized_model.agent.model_dump() == agent.model_dump()
223224
assert deserialized_model.model_dump() == model.model_dump()
225+
226+
227+
def test_include_default_tools_serialization_default() -> None:
228+
"""Test that include_default_tools serializes correctly with default value."""
229+
llm = LLM(model="test-model", usage_id="test-llm")
230+
agent = Agent(llm=llm, tools=[])
231+
232+
# Serialize to JSON
233+
agent_json = agent.model_dump_json()
234+
agent_dict = json.loads(agent_json)
235+
236+
# Default should include both FinishTool and ThinkTool as strings
237+
assert "include_default_tools" in agent_dict
238+
assert set(agent_dict["include_default_tools"]) == {"FinishTool", "ThinkTool"}
239+
240+
241+
def test_include_default_tools_serialization_empty() -> None:
242+
"""Test that include_default_tools serializes correctly when empty."""
243+
llm = LLM(model="test-model", usage_id="test-llm")
244+
agent = Agent(llm=llm, tools=[], include_default_tools=[])
245+
246+
# Serialize to JSON
247+
agent_json = agent.model_dump_json()
248+
agent_dict = json.loads(agent_json)
249+
250+
# Should be empty list
251+
assert agent_dict["include_default_tools"] == []
252+
253+
254+
def test_include_default_tools_serialization_partial() -> None:
255+
"""Test that include_default_tools serializes correctly with partial list."""
256+
llm = LLM(model="test-model", usage_id="test-llm")
257+
agent = Agent(llm=llm, tools=[], include_default_tools=["FinishTool"])
258+
259+
# Serialize to JSON
260+
agent_json = agent.model_dump_json()
261+
agent_dict = json.loads(agent_json)
262+
263+
# Should be serialized as string
264+
assert agent_dict["include_default_tools"] == ["FinishTool"]
265+
266+
267+
def test_include_default_tools_deserialization_roundtrip() -> None:
268+
"""Test that include_default_tools deserializes correctly after round-trip."""
269+
llm = LLM(model="test-model", usage_id="test-llm")
270+
agent = Agent(llm=llm, tools=[], include_default_tools=["FinishTool"])
271+
272+
# Serialize to JSON
273+
agent_json = agent.model_dump_json()
274+
275+
# Deserialize from JSON
276+
deserialized_agent = AgentBase.model_validate_json(agent_json)
277+
278+
# Should have the same include_default_tools
279+
assert isinstance(deserialized_agent, Agent)
280+
assert deserialized_agent.include_default_tools == ["FinishTool"]
281+
282+
283+
def test_include_default_tools_deserialization_all_tools() -> None:
284+
"""Test that include_default_tools deserializes correctly with all tools."""
285+
llm = LLM(model="test-model", usage_id="test-llm")
286+
agent = Agent(llm=llm, tools=[], include_default_tools=["FinishTool", "ThinkTool"])
287+
288+
# Serialize to JSON
289+
agent_json = agent.model_dump_json()
290+
291+
# Deserialize from JSON
292+
deserialized_agent = AgentBase.model_validate_json(agent_json)
293+
294+
# Should have both tools
295+
assert isinstance(deserialized_agent, Agent)
296+
assert set(deserialized_agent.include_default_tools) == {"FinishTool", "ThinkTool"}
297+
298+
299+
def test_include_default_tools_deserialization_empty() -> None:
300+
"""Test that include_default_tools deserializes correctly when empty."""
301+
llm = LLM(model="test-model", usage_id="test-llm")
302+
agent = Agent(llm=llm, tools=[], include_default_tools=[])
303+
304+
# Serialize to JSON
305+
agent_json = agent.model_dump_json()
306+
307+
# Deserialize from JSON
308+
deserialized_agent = AgentBase.model_validate_json(agent_json)
309+
310+
# Should be empty
311+
assert isinstance(deserialized_agent, Agent)
312+
assert deserialized_agent.include_default_tools == []
313+
314+
315+
def test_include_default_tools_deserialization_from_dict() -> None:
316+
"""Test that include_default_tools deserializes correctly from dict."""
317+
agent_dict = {
318+
"llm": {"model": "test-model", "usage_id": "test-llm"},
319+
"tools": [],
320+
"include_default_tools": ["ThinkTool"],
321+
"kind": "Agent",
322+
}
323+
324+
# Deserialize from dict
325+
agent = AgentBase.model_validate(agent_dict)
326+
327+
# Should have ThinkTool
328+
assert isinstance(agent, Agent)
329+
assert agent.include_default_tools == ["ThinkTool"]
330+
331+
332+
def test_include_default_tools_invalid_tool_name() -> None:
333+
"""Test that include_default_tools raises error for invalid tool name."""
334+
agent_dict = {
335+
"llm": {"model": "test-model", "usage_id": "test-llm"},
336+
"tools": [],
337+
"include_default_tools": ["InvalidTool"],
338+
"kind": "Agent",
339+
}
340+
341+
# Should raise ValueError
342+
with pytest.raises(ValueError, match="Unknown built-in tool class: 'InvalidTool'"):
343+
AgentBase.model_validate(agent_dict)
344+
345+
346+
def test_include_default_tools_accepts_tool_class_via_validator() -> None:
347+
"""Test that include_default_tools validator converts tool classes to strings."""
348+
# The validator accepts tool classes and converts them to strings
349+
# This is tested via model_validate to bypass type checking
350+
agent_dict = {
351+
"llm": {"model": "test-model", "usage_id": "test-llm"},
352+
"tools": [],
353+
"include_default_tools": [FinishTool], # Pass class, not string
354+
"kind": "Agent",
355+
}
356+
357+
agent = AgentBase.model_validate(agent_dict)
358+
359+
# Should be converted to string
360+
assert isinstance(agent, Agent)
361+
assert agent.include_default_tools == ["FinishTool"]

0 commit comments

Comments
 (0)