Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,35 @@ async def test_async_bidi_stream_query(self):
events.append(event)
assert len(events) == 1

@pytest.mark.asyncio
async def test_async_bidi_stream_query_with_state(self):
app = reasoning_engines.AdkApp(
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
request_queue = asyncio.Queue()
request_dict = {
"user_id": _TEST_USER_ID,
"state": {"test_key": "test_val"},
"live_request": {
"input": "What is the exchange rate from USD to SEK?",
},
}

await request_queue.put(request_dict)
await request_queue.put(None) # Sentinel to end the stream.

with mock.patch.object(
app, "async_create_session", wraps=app.async_create_session
) as mock_create_session:
async for _ in app.bidi_stream_query(request_queue):
pass
mock_create_session.assert_called_once_with(
user_id=_TEST_USER_ID, state={"test_key": "test_val"}
)

def test_create_session(self):
app = reasoning_engines.AdkApp(
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
Expand Down
3 changes: 2 additions & 1 deletion vertexai/preview/reasoning_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,8 @@ async def bidi_stream_query(
if not self._tmpl_attrs.get("runner"):
self.set_up()
if not session_id:
session = await self.async_create_session(user_id=user_id)
state = first_request.get("state")
session = await self.async_create_session(user_id=user_id, state=state)
session_id = session.id
run_config = _validate_run_config(run_config)

Expand Down
Loading