Skip to content

Commit 33a9789

Browse files
authored
Merge branch 'main' into feat/allow-thinking-config-4108
2 parents cd2c510 + ec6abf4 commit 33a9789

File tree

2 files changed

+96
-5
lines changed

2 files changed

+96
-5
lines changed

src/google/adk/runners.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,12 +1015,15 @@ async def run_live(
10151015
# Pre-processing for live streaming tools
10161016
# Inspect the tool's parameters to find if it uses LiveRequestQueue
10171017
invocation_context.active_streaming_tools = {}
1018-
# TODO(hangfei): switch to use canonical_tools.
1019-
# for shell agents, there is no tools associated with it so we should skip.
1020-
if hasattr(invocation_context.agent, 'tools'):
1018+
# For shell agents, there is no canonical_tools method so we should skip.
1019+
if hasattr(invocation_context.agent, 'canonical_tools'):
10211020
import inspect
10221021

1023-
for tool in invocation_context.agent.tools:
1022+
# Use canonical_tools to get properly wrapped BaseTool instances
1023+
canonical_tools = await invocation_context.agent.canonical_tools(
1024+
invocation_context
1025+
)
1026+
for tool in canonical_tools:
10241027
# We use `inspect.signature()` to examine the tool's underlying function (`tool.func`).
10251028
# This approach is deliberately chosen over `typing.get_type_hints()` for robustness.
10261029
#
@@ -1044,10 +1047,14 @@ async def run_live(
10441047
if param.annotation is LiveRequestQueue:
10451048
if not invocation_context.active_streaming_tools:
10461049
invocation_context.active_streaming_tools = {}
1050+
1051+
logger.debug(
1052+
'Register streaming tool with input stream: %s', tool.name
1053+
)
10471054
active_streaming_tool = ActiveStreamingTool(
10481055
stream=LiveRequestQueue()
10491056
)
1050-
invocation_context.active_streaming_tools[tool.__name__] = (
1057+
invocation_context.active_streaming_tools[tool.name] = (
10511058
active_streaming_tool
10521059
)
10531060

tests/unittests/test_runners.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.adk.agents.base_agent import BaseAgent
2424
from google.adk.agents.context_cache_config import ContextCacheConfig
2525
from google.adk.agents.invocation_context import InvocationContext
26+
from google.adk.agents.live_request_queue import LiveRequestQueue
2627
from google.adk.agents.llm_agent import LlmAgent
2728
from google.adk.agents.run_config import RunConfig
2829
from google.adk.apps.app import App
@@ -34,6 +35,7 @@
3435
from google.adk.runners import Runner
3536
from google.adk.sessions.in_memory_session_service import InMemorySessionService
3637
from google.adk.sessions.session import Session
38+
from google.adk.tools.function_tool import FunctionTool
3739
from google.genai import types
3840
import pytest
3941

@@ -358,6 +360,88 @@ async def test_run_live_auto_create_session():
358360
assert session is not None
359361

360362

363+
@pytest.mark.asyncio
364+
async def test_run_live_detects_streaming_tools_with_canonical_tools():
365+
"""run_live should detect streaming tools using canonical_tools and tool.name."""
366+
367+
# Define streaming tools - one as raw function, one wrapped in FunctionTool
368+
async def raw_streaming_tool(
369+
input_stream: LiveRequestQueue,
370+
) -> AsyncGenerator[str, None]:
371+
"""A raw streaming tool function."""
372+
yield "test"
373+
374+
async def wrapped_streaming_tool(
375+
input_stream: LiveRequestQueue,
376+
) -> AsyncGenerator[str, None]:
377+
"""A streaming tool wrapped in FunctionTool."""
378+
yield "test"
379+
380+
def non_streaming_tool(param: str) -> str:
381+
"""A regular non-streaming tool."""
382+
return param
383+
384+
# Create a mock LlmAgent that yields an event and captures invocation context
385+
captured_context = {}
386+
387+
class StreamingToolsAgent(LlmAgent):
388+
389+
async def _run_live_impl(
390+
self, invocation_context: InvocationContext
391+
) -> AsyncGenerator[Event, None]:
392+
# Capture the active_streaming_tools for verification
393+
captured_context["active_streaming_tools"] = (
394+
invocation_context.active_streaming_tools
395+
)
396+
yield Event(
397+
invocation_id=invocation_context.invocation_id,
398+
author=self.name,
399+
content=types.Content(
400+
role="model", parts=[types.Part(text="streaming test")]
401+
),
402+
)
403+
404+
agent = StreamingToolsAgent(
405+
name="streaming_agent",
406+
model="gemini-2.0-flash",
407+
tools=[
408+
raw_streaming_tool, # Raw function
409+
FunctionTool(wrapped_streaming_tool), # Wrapped in FunctionTool
410+
non_streaming_tool, # Non-streaming tool (should not be detected)
411+
],
412+
)
413+
414+
session_service = InMemorySessionService()
415+
artifact_service = InMemoryArtifactService()
416+
runner = Runner(
417+
app_name="streaming_test_app",
418+
agent=agent,
419+
session_service=session_service,
420+
artifact_service=artifact_service,
421+
auto_create_session=True,
422+
)
423+
424+
live_queue = LiveRequestQueue()
425+
426+
agen = runner.run_live(
427+
user_id="user",
428+
session_id="test_session",
429+
live_request_queue=live_queue,
430+
)
431+
432+
event = await agen.__anext__()
433+
await agen.aclose()
434+
435+
assert event.author == "streaming_agent"
436+
437+
# Verify streaming tools were detected correctly
438+
active_tools = captured_context.get("active_streaming_tools", {})
439+
assert "raw_streaming_tool" in active_tools
440+
assert "wrapped_streaming_tool" in active_tools
441+
# Non-streaming tool should not be detected
442+
assert "non_streaming_tool" not in active_tools
443+
444+
361445
@pytest.mark.asyncio
362446
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
363447
project_root = tmp_path / "workspace"

0 commit comments

Comments
 (0)