|
23 | 23 | from google.adk.agents.base_agent import BaseAgent |
24 | 24 | from google.adk.agents.context_cache_config import ContextCacheConfig |
25 | 25 | from google.adk.agents.invocation_context import InvocationContext |
| 26 | +from google.adk.agents.live_request_queue import LiveRequestQueue |
26 | 27 | from google.adk.agents.llm_agent import LlmAgent |
27 | 28 | from google.adk.agents.run_config import RunConfig |
28 | 29 | from google.adk.apps.app import App |
|
34 | 35 | from google.adk.runners import Runner |
35 | 36 | from google.adk.sessions.in_memory_session_service import InMemorySessionService |
36 | 37 | from google.adk.sessions.session import Session |
| 38 | +from google.adk.tools.function_tool import FunctionTool |
37 | 39 | from google.genai import types |
38 | 40 | import pytest |
39 | 41 |
|
@@ -358,6 +360,88 @@ async def test_run_live_auto_create_session(): |
358 | 360 | assert session is not None |
359 | 361 |
|
360 | 362 |
|
| 363 | +@pytest.mark.asyncio |
| 364 | +async def test_run_live_detects_streaming_tools_with_canonical_tools(): |
| 365 | + """run_live should detect streaming tools using canonical_tools and tool.name.""" |
| 366 | + |
| 367 | + # Define streaming tools - one as raw function, one wrapped in FunctionTool |
| 368 | + async def raw_streaming_tool( |
| 369 | + input_stream: LiveRequestQueue, |
| 370 | + ) -> AsyncGenerator[str, None]: |
| 371 | + """A raw streaming tool function.""" |
| 372 | + yield "test" |
| 373 | + |
| 374 | + async def wrapped_streaming_tool( |
| 375 | + input_stream: LiveRequestQueue, |
| 376 | + ) -> AsyncGenerator[str, None]: |
| 377 | + """A streaming tool wrapped in FunctionTool.""" |
| 378 | + yield "test" |
| 379 | + |
| 380 | + def non_streaming_tool(param: str) -> str: |
| 381 | + """A regular non-streaming tool.""" |
| 382 | + return param |
| 383 | + |
| 384 | + # Create a mock LlmAgent that yields an event and captures invocation context |
| 385 | + captured_context = {} |
| 386 | + |
| 387 | + class StreamingToolsAgent(LlmAgent): |
| 388 | + |
| 389 | + async def _run_live_impl( |
| 390 | + self, invocation_context: InvocationContext |
| 391 | + ) -> AsyncGenerator[Event, None]: |
| 392 | + # Capture the active_streaming_tools for verification |
| 393 | + captured_context["active_streaming_tools"] = ( |
| 394 | + invocation_context.active_streaming_tools |
| 395 | + ) |
| 396 | + yield Event( |
| 397 | + invocation_id=invocation_context.invocation_id, |
| 398 | + author=self.name, |
| 399 | + content=types.Content( |
| 400 | + role="model", parts=[types.Part(text="streaming test")] |
| 401 | + ), |
| 402 | + ) |
| 403 | + |
| 404 | + agent = StreamingToolsAgent( |
| 405 | + name="streaming_agent", |
| 406 | + model="gemini-2.0-flash", |
| 407 | + tools=[ |
| 408 | + raw_streaming_tool, # Raw function |
| 409 | + FunctionTool(wrapped_streaming_tool), # Wrapped in FunctionTool |
| 410 | + non_streaming_tool, # Non-streaming tool (should not be detected) |
| 411 | + ], |
| 412 | + ) |
| 413 | + |
| 414 | + session_service = InMemorySessionService() |
| 415 | + artifact_service = InMemoryArtifactService() |
| 416 | + runner = Runner( |
| 417 | + app_name="streaming_test_app", |
| 418 | + agent=agent, |
| 419 | + session_service=session_service, |
| 420 | + artifact_service=artifact_service, |
| 421 | + auto_create_session=True, |
| 422 | + ) |
| 423 | + |
| 424 | + live_queue = LiveRequestQueue() |
| 425 | + |
| 426 | + agen = runner.run_live( |
| 427 | + user_id="user", |
| 428 | + session_id="test_session", |
| 429 | + live_request_queue=live_queue, |
| 430 | + ) |
| 431 | + |
| 432 | + event = await agen.__anext__() |
| 433 | + await agen.aclose() |
| 434 | + |
| 435 | + assert event.author == "streaming_agent" |
| 436 | + |
| 437 | + # Verify streaming tools were detected correctly |
| 438 | + active_tools = captured_context.get("active_streaming_tools", {}) |
| 439 | + assert "raw_streaming_tool" in active_tools |
| 440 | + assert "wrapped_streaming_tool" in active_tools |
| 441 | + # Non-streaming tool should not be detected |
| 442 | + assert "non_streaming_tool" not in active_tools |
| 443 | + |
| 444 | + |
361 | 445 | @pytest.mark.asyncio |
362 | 446 | async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch): |
363 | 447 | project_root = tmp_path / "workspace" |
|
0 commit comments