|
| 1 | +from typing import Annotated |
| 2 | + |
1 | 3 | import pytest |
2 | 4 | from chatlas import ChatOpenAI |
3 | 5 | from chatlas._content import ToolInfo |
@@ -117,6 +119,31 @@ def add(x: int, y: int) -> int: |
117 | 119 | assert props["x"]["description"] == "First number" # type: ignore |
118 | 120 | assert props["y"]["description"] == "Second number" # type: ignore |
119 | 121 |
|
| 122 | + def test_from_func_with_annotated_model(self): |
| 123 | + """Test creating a Tool with a model using Annotated fields.""" |
| 124 | + |
| 125 | + class AddParams(BaseModel): |
| 126 | + """Parameters for adding numbers.""" |
| 127 | + |
| 128 | + x: Annotated[int, Field(description="First number", ge=0)] |
| 129 | + y: Annotated[int, Field(description="Second number", le=100)] |
| 130 | + |
| 131 | + def add(x: int, y: int) -> int: |
| 132 | + return x + y |
| 133 | + |
| 134 | + tool = Tool.from_func(add, model=AddParams) |
| 135 | + |
| 136 | + assert tool.name == "AddParams" |
| 137 | + func = tool.schema["function"] |
| 138 | + |
| 139 | + # Check that Annotated Field descriptions and constraints are preserved |
| 140 | + params = func.get("parameters", {}) |
| 141 | + props = params["properties"] |
| 142 | + assert props["x"]["description"] == "First number" |
| 143 | + assert props["x"]["minimum"] == 0 |
| 144 | + assert props["y"]["description"] == "Second number" |
| 145 | + assert props["y"]["maximum"] == 100 |
| 146 | + |
120 | 147 | def test_from_func_with_model_missing_default_error(self): |
121 | 148 | """Test that error is raised when function has default but model doesn't. |
122 | 149 |
|
@@ -214,6 +241,170 @@ async def async_add(x: int, y: int) -> int: |
214 | 241 | assert func.get("description") == "Add two numbers asynchronously." |
215 | 242 |
|
216 | 243 |
|
| 244 | +class TestAnnotatedParameters: |
| 245 | + """Test support for typing.Annotated with pydantic.Field for parameter descriptions.""" |
| 246 | + |
| 247 | + def test_annotated_field_descriptions(self): |
| 248 | + """Test that Field descriptions in Annotated types are extracted.""" |
| 249 | + |
| 250 | + def add_numbers( |
| 251 | + x: Annotated[int, Field(description="The first number to be added")], |
| 252 | + y: Annotated[int, Field(description="The second number to be added")], |
| 253 | + ) -> int: |
| 254 | + """Add two numbers""" |
| 255 | + return x + y |
| 256 | + |
| 257 | + tool = Tool.from_func(add_numbers) |
| 258 | + |
| 259 | + assert tool.name == "add_numbers" |
| 260 | + func = tool.schema["function"] |
| 261 | + assert func.get("description") == "Add two numbers" |
| 262 | + |
| 263 | + params = func.get("parameters", {}) |
| 264 | + props = params["properties"] |
| 265 | + assert props["x"]["description"] == "The first number to be added" |
| 266 | + assert props["y"]["description"] == "The second number to be added" |
| 267 | + assert props["x"]["type"] == "integer" |
| 268 | + assert props["y"]["type"] == "integer" |
| 269 | + |
| 270 | + def test_annotated_with_default_value(self): |
| 271 | + """Test Annotated parameters with default values in function signature.""" |
| 272 | + |
| 273 | + def greet( |
| 274 | + name: Annotated[str, Field(description="Name to greet")], |
| 275 | + greeting: Annotated[str, Field(description="Greeting phrase")] = "Hello", |
| 276 | + ) -> str: |
| 277 | + """Generate a greeting""" |
| 278 | + return f"{greeting}, {name}!" |
| 279 | + |
| 280 | + tool = Tool.from_func(greet) |
| 281 | + func = tool.schema["function"] |
| 282 | + params = func.get("parameters", {}) |
| 283 | + |
| 284 | + # Check descriptions are preserved |
| 285 | + props = params["properties"] |
| 286 | + assert props["name"]["description"] == "Name to greet" |
| 287 | + assert props["greeting"]["description"] == "Greeting phrase" |
| 288 | + # Default value is preserved in schema |
| 289 | + assert props["greeting"]["default"] == "Hello" |
| 290 | + |
| 291 | + def test_annotated_with_field_default(self): |
| 292 | + """Test Annotated parameters with default in Field (not function signature).""" |
| 293 | + |
| 294 | + def process( |
| 295 | + value: Annotated[int, Field(description="Value to process", default=42)], |
| 296 | + ) -> int: |
| 297 | + """Process a value""" |
| 298 | + return value * 2 |
| 299 | + |
| 300 | + tool = Tool.from_func(process) |
| 301 | + func = tool.schema["function"] |
| 302 | + params = func.get("parameters", {}) |
| 303 | + |
| 304 | + props = params["properties"] |
| 305 | + assert props["value"]["description"] == "Value to process" |
| 306 | + assert props["value"]["default"] == 42 |
| 307 | + |
| 308 | + def test_annotated_function_default_overrides_field_default(self): |
| 309 | + """Test that function signature default takes precedence over Field default.""" |
| 310 | + |
| 311 | + def example( |
| 312 | + x: Annotated[int, Field(description="A number", default=10)] = 20, |
| 313 | + ) -> int: |
| 314 | + """Example function""" |
| 315 | + return x |
| 316 | + |
| 317 | + tool = Tool.from_func(example) |
| 318 | + func = tool.schema["function"] |
| 319 | + params = func.get("parameters", {}) |
| 320 | + |
| 321 | + props = params["properties"] |
| 322 | + # Function signature default (20) should override Field default (10) |
| 323 | + assert props["x"]["default"] == 20 |
| 324 | + |
| 325 | + def test_mixed_annotated_and_regular_parameters(self): |
| 326 | + """Test functions with both Annotated and regular parameters.""" |
| 327 | + |
| 328 | + def mixed_func( |
| 329 | + described: Annotated[str, Field(description="A described parameter")], |
| 330 | + plain: int, |
| 331 | + ) -> str: |
| 332 | + """Function with mixed parameter styles""" |
| 333 | + return f"{described}: {plain}" |
| 334 | + |
| 335 | + tool = Tool.from_func(mixed_func) |
| 336 | + func = tool.schema["function"] |
| 337 | + params = func.get("parameters", {}) |
| 338 | + props = params["properties"] |
| 339 | + |
| 340 | + # Annotated param should have description |
| 341 | + assert props["described"]["description"] == "A described parameter" |
| 342 | + |
| 343 | + # Plain param should not have description |
| 344 | + assert "description" not in props["plain"] |
| 345 | + |
| 346 | + def test_annotated_with_underscore_prefix(self): |
| 347 | + """Test Annotated parameters with underscore prefix (private-style names).""" |
| 348 | + |
| 349 | + def func_with_private( |
| 350 | + _private: Annotated[int, Field(description="A private-style param")], |
| 351 | + ) -> int: |
| 352 | + """Function with underscore-prefixed param""" |
| 353 | + return _private |
| 354 | + |
| 355 | + tool = Tool.from_func(func_with_private) |
| 356 | + func = tool.schema["function"] |
| 357 | + params = func.get("parameters", {}) |
| 358 | + props = params["properties"] |
| 359 | + |
| 360 | + # Schema uses the alias (_private) as the property key |
| 361 | + assert "_private" in props |
| 362 | + assert props["_private"]["description"] == "A private-style param" |
| 363 | + |
| 364 | + def test_annotated_registration_via_chat(self): |
| 365 | + """Test that Annotated tools work when registered via Chat.register_tool().""" |
| 366 | + chat = ChatOpenAI() |
| 367 | + |
| 368 | + def add_numbers( |
| 369 | + x: Annotated[int, Field(description="The first number")], |
| 370 | + y: Annotated[int, Field(description="The second number")], |
| 371 | + ) -> int: |
| 372 | + """Add two numbers""" |
| 373 | + return x + y |
| 374 | + |
| 375 | + chat.register_tool(add_numbers) |
| 376 | + |
| 377 | + tools = chat.get_tools() |
| 378 | + assert len(tools) == 1 |
| 379 | + |
| 380 | + tool = tools[0] |
| 381 | + func = tool.schema["function"] |
| 382 | + params = func.get("parameters", {}) |
| 383 | + props = params["properties"] |
| 384 | + |
| 385 | + assert props["x"]["description"] == "The first number" |
| 386 | + assert props["y"]["description"] == "The second number" |
| 387 | + |
| 388 | + def test_annotated_with_complex_types(self): |
| 389 | + """Test Annotated with more complex types.""" |
| 390 | + from typing import Optional |
| 391 | + |
| 392 | + def search( |
| 393 | + query: Annotated[str, Field(description="Search query string")], |
| 394 | + limit: Annotated[Optional[int], Field(description="Maximum results")] = None, |
| 395 | + ) -> str: |
| 396 | + """Search for items""" |
| 397 | + return f"Searching: {query}" |
| 398 | + |
| 399 | + tool = Tool.from_func(search) |
| 400 | + func = tool.schema["function"] |
| 401 | + params = func.get("parameters", {}) |
| 402 | + props = params["properties"] |
| 403 | + |
| 404 | + assert props["query"]["description"] == "Search query string" |
| 405 | + assert props["limit"]["description"] == "Maximum results" |
| 406 | + |
| 407 | + |
217 | 408 | class TestChatGetSetTools: |
218 | 409 | """Test Chat.get_tools() and Chat.set_tools() methods.""" |
219 | 410 |
|
|
0 commit comments