Skip to content

Commit 1725929

Browse files
authored
prevent tool cancellation when AgentTask is called inside it (#4586)
1 parent f689be6 commit 1725929

File tree

4 files changed

+102
-109
lines changed

4 files changed

+102
-109
lines changed

examples/voice_agents/email_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def register_for_event(self, context: RunContext):
5555
async def entrypoint(ctx: JobContext):
5656
session = AgentSession(
5757
vad=silero.VAD.load(),
58-
llm=inference.LLM("google/gemini-2.5-flash"),
58+
llm=inference.LLM("openai/gpt-4.1-mini"),
5959
stt=inference.STT("deepgram/nova-3"),
6060
tts=inference.TTS("cartesia/sonic-3"),
6161
)

livekit-agents/livekit/agents/voice/agent.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,18 @@ def _handle_task_done(_: asyncio.Task[Any]) -> None:
756756
old_agent = old_activity.agent
757757
session = old_activity.session
758758

759+
old_allow_interruptions = True
760+
if speech_handle:
761+
if speech_handle.interrupted:
762+
raise RuntimeError(
763+
f"{self.__class__.__name__} cannot be awaited inside a function tool that is already interrupted"
764+
)
765+
766+
# lock the speech handle to prevent interruptions until the task is complete
767+
# there should be no await before this line to avoid race conditions
768+
old_allow_interruptions = speech_handle.allow_interruptions
769+
speech_handle.allow_interruptions = False
770+
759771
blocked_tasks = [current_task]
760772
if (
761773
old_activity._on_enter_task
@@ -790,6 +802,9 @@ def _handle_task_done(_: asyncio.Task[Any]) -> None:
790802
return await asyncio.shield(self.__fut)
791803

792804
finally:
805+
if speech_handle:
806+
speech_handle.allow_interruptions = old_allow_interruptions
807+
793808
# run_state could have changed after self.__fut
794809
run_state = session._global_run_state
795810

livekit-agents/livekit/agents/voice/agent_activity.py

Lines changed: 82 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def __init__(self, agent: Agent, sess: AgentSession) -> None:
123123
# for false interruption handling
124124
self._paused_speech: SpeechHandle | None = None
125125
self._false_interruption_timer: asyncio.TimerHandle | None = None
126-
self._interrupt_paused_speech_task: asyncio.Task[None] | None = None
126+
self._cancel_speech_pause_task: asyncio.Task[None] | None = None
127+
127128
self._stt_eos_received: bool = False
128129

129130
# fired when a speech_task finishes or when a new speech_handle is scheduled
@@ -754,8 +755,11 @@ async def _close_session(self) -> None:
754755
*(mcp_server.aclose() for mcp_server in self.mcp_servers), return_exceptions=True
755756
)
756757

757-
await self._interrupt_paused_speech(old_task=self._interrupt_paused_speech_task)
758-
self._interrupt_paused_speech_task = None
758+
await self._cancel_speech_pause(
759+
old_task=self._cancel_speech_pause_task,
760+
interrupt=False, # don't interrupt the paused speech, it's managed by _pause_scheduling_task
761+
)
762+
self._cancel_speech_pause_task = None
759763

760764
async def aclose(self) -> None:
761765
# `aclose` must only be called by AgentSession
@@ -1371,8 +1375,8 @@ def on_final_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None = No
13711375
# schedule a resume timer if interrupted after end_of_speech
13721376
self._start_false_interruption_timer(timeout)
13731377

1374-
self._interrupt_paused_speech_task = asyncio.create_task(
1375-
self._interrupt_paused_speech(old_task=self._interrupt_paused_speech_task)
1378+
self._cancel_speech_pause_task = asyncio.create_task(
1379+
self._cancel_speech_pause(old_task=self._cancel_speech_pause_task)
13761380
)
13771381

13781382
def on_preemptive_generation(self, info: _PreemptiveGenerationInfo) -> None:
@@ -1490,7 +1494,7 @@ async def _user_turn_completed_task(
14901494
extra={"user_input": info.new_transcript},
14911495
)
14921496
return
1493-
await self._interrupt_paused_speech(self._interrupt_paused_speech_task)
1497+
await self._cancel_speech_pause(self._cancel_speech_pause_task)
14941498

14951499
await current_speech.interrupt()
14961500

@@ -2079,20 +2083,16 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
20792083
)
20802084

20812085
current_span.set_attribute(trace_types.ATTR_SPEECH_INTERRUPTED, speech_handle.interrupted)
2082-
has_speech_message = False
20832086

20842087
# add the tools messages that triggers this reply to the chat context
20852088
if _previous_tools_messages:
20862089
self._agent._chat_ctx.insert(_previous_tools_messages)
20872090
self._session._tool_items_added(_previous_tools_messages)
20882091

2092+
forwarded_text = text_out.text if text_out else ""
20892093
if speech_handle.interrupted:
20902094
await utils.aio.cancel_and_wait(*tasks)
2091-
await text_tee.aclose()
20922095

2093-
forwarded_text = text_out.text if text_out else ""
2094-
if forwarded_text:
2095-
has_speech_message = True
20962096
# if the audio playout was enabled, clear the buffer
20972097
if audio_output is not None:
20982098
audio_output.clear_buffer()
@@ -2109,55 +2109,39 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
21092109
else:
21102110
forwarded_text = ""
21112111

2112-
if forwarded_text:
2113-
msg = chat_ctx.add_message(
2114-
role="assistant",
2115-
content=forwarded_text,
2116-
id=llm_gen_data.id,
2117-
interrupted=True,
2118-
created_at=reply_started_at,
2119-
metrics=assistant_metrics,
2120-
)
2121-
self._agent._chat_ctx.insert(msg)
2122-
self._session._conversation_item_added(msg)
2123-
speech_handle._item_added([msg])
2124-
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text)
2125-
2126-
if self._session.agent_state == "speaking":
2127-
self._session._update_agent_state("listening")
2128-
2129-
speech_handle._mark_generation_done()
2130-
await utils.aio.cancel_and_wait(exe_task)
2131-
return
2132-
2133-
if read_transcript_from_tts and text_out and not text_out.text:
2112+
elif read_transcript_from_tts and text_out and not text_out.text:
21342113
logger.warning(
21352114
"`use_tts_aligned_transcript` is enabled but no agent transcript was returned from tts"
21362115
)
21372116

2138-
if text_out and text_out.text:
2139-
has_speech_message = True
2117+
if forwarded_text:
21402118
msg = chat_ctx.add_message(
21412119
role="assistant",
2142-
content=text_out.text,
2120+
content=forwarded_text,
21432121
id=llm_gen_data.id,
2144-
interrupted=False,
2122+
interrupted=speech_handle.interrupted,
21452123
created_at=reply_started_at,
21462124
metrics=assistant_metrics,
21472125
)
21482126
self._agent._chat_ctx.insert(msg)
21492127
self._session._conversation_item_added(msg)
21502128
speech_handle._item_added([msg])
2151-
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, text_out.text)
2129+
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text)
21522130

2153-
if len(tool_output.output) > 0:
2131+
if not speech_handle.interrupted and len(tool_output.output) > 0:
21542132
self._session._update_agent_state("thinking")
21552133
elif self._session.agent_state == "speaking":
21562134
self._session._update_agent_state("listening")
21572135

21582136
await text_tee.aclose()
21592137

21602138
speech_handle._mark_generation_done() # mark the playout done before waiting for the tool execution # noqa: E501
2139+
2140+
if speech_handle.interrupted:
2141+
await utils.aio.cancel_and_wait(exe_task)
2142+
return
2143+
2144+
# wait for the tool execution to complete
21612145
self._background_speeches.add(speech_handle)
21622146
try:
21632147
await exe_task
@@ -2229,7 +2213,7 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
22292213
),
22302214
# in case the current reply only generated tools (no speech), re-use the current user_metrics for the next
22312215
# tool response generation
2232-
_previous_user_metrics=user_metrics if not has_speech_message else None,
2216+
_previous_user_metrics=user_metrics if not forwarded_text else None,
22332217
_previous_tools_messages=tool_messages,
22342218
),
22352219
speech_handle=speech_handle,
@@ -2580,83 +2564,66 @@ def _create_assistant_message(
25802564
msg.metrics = assistant_metrics
25812565
return msg
25822566

2567+
msg_gen, text_out, audio_out = (
2568+
message_outputs[0] if len(message_outputs) > 0 else (None, None, None)
2569+
) # there should be only one message
2570+
2571+
forwarded_text = text_out.text if text_out else ""
25832572
if speech_handle.interrupted:
25842573
await utils.aio.cancel_and_wait(*tasks)
25852574

2586-
if len(message_outputs) > 0:
2587-
# there should be only one message
2588-
msg_gen, text_out, audio_out = message_outputs[0]
2589-
forwarded_text = text_out.text if text_out else ""
2590-
if audio_output is not None:
2591-
audio_output.clear_buffer()
2575+
if msg_gen and audio_output is not None:
2576+
audio_output.clear_buffer()
25922577

2593-
playback_ev = await audio_output.wait_for_playout()
2594-
playback_position = playback_ev.playback_position
2595-
if (
2596-
audio_out is not None
2597-
and audio_out.first_frame_fut.done()
2598-
and not audio_out.first_frame_fut.cancelled()
2599-
):
2600-
# playback_ev is valid only if the first frame was already played
2601-
if playback_ev.synchronized_transcript is not None:
2602-
forwarded_text = playback_ev.synchronized_transcript
2603-
else:
2604-
forwarded_text = ""
2605-
playback_position = 0
2606-
2607-
# truncate server-side message (if supported)
2608-
if self.llm.capabilities.message_truncation:
2609-
msg_modalities = await msg_gen.modalities
2610-
self._rt_session.truncate(
2611-
message_id=msg_gen.message_id,
2612-
modalities=msg_modalities,
2613-
audio_end_ms=int(playback_position * 1000),
2614-
audio_transcript=forwarded_text,
2615-
)
2578+
playback_ev = await audio_output.wait_for_playout()
2579+
playback_position = playback_ev.playback_position
2580+
if (
2581+
audio_out is not None
2582+
and audio_out.first_frame_fut.done()
2583+
and not audio_out.first_frame_fut.cancelled()
2584+
):
2585+
# playback_ev is valid only if the first frame was already played
2586+
if playback_ev.synchronized_transcript is not None:
2587+
forwarded_text = playback_ev.synchronized_transcript
2588+
else:
2589+
forwarded_text = ""
2590+
playback_position = 0
26162591

2617-
msg: llm.ChatMessage | None = None
2618-
if forwarded_text:
2619-
msg = _create_assistant_message(
2592+
# truncate server-side message (if supported)
2593+
if self.llm.capabilities.message_truncation:
2594+
msg_modalities = await msg_gen.modalities
2595+
self._rt_session.truncate(
26202596
message_id=msg_gen.message_id,
2621-
forwarded_text=forwarded_text,
2622-
interrupted=True,
2597+
modalities=msg_modalities,
2598+
audio_end_ms=int(playback_position * 1000),
2599+
audio_transcript=forwarded_text,
26232600
)
2624-
self._agent._chat_ctx.items.append(msg)
2625-
speech_handle._item_added([msg])
2626-
self._session._conversation_item_added(msg)
2627-
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text)
26282601

2629-
speech_handle._mark_generation_done()
2630-
await utils.aio.cancel_and_wait(exe_task)
2631-
2632-
for tee in tees:
2633-
await tee.aclose()
2634-
return
2635-
2636-
if len(message_outputs) > 0:
2637-
# there should be only one message
2638-
msg_gen, text_out, _ = message_outputs[0]
2639-
forwarded_text = text_out.text if text_out else ""
2640-
if forwarded_text:
2641-
msg = _create_assistant_message(
2642-
message_id=msg_gen.message_id,
2643-
forwarded_text=forwarded_text,
2644-
interrupted=False,
2645-
)
2646-
self._agent._chat_ctx.items.append(msg)
2647-
speech_handle._item_added([msg])
2648-
self._session._conversation_item_added(msg)
2649-
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text)
2602+
elif read_transcript_from_tts and text_out and not text_out.text:
2603+
logger.warning(
2604+
"`use_tts_aligned_transcript` is enabled but no agent transcript was returned from tts"
2605+
)
26502606

2651-
elif read_transcript_from_tts and text_out is not None:
2652-
logger.warning(
2653-
"`use_tts_aligned_transcript` is enabled but no agent transcript was returned from tts"
2654-
)
2607+
if msg_gen and forwarded_text:
2608+
msg = _create_assistant_message(
2609+
message_id=msg_gen.message_id,
2610+
forwarded_text=forwarded_text,
2611+
interrupted=speech_handle.interrupted,
2612+
)
2613+
self._agent._chat_ctx.items.append(msg)
2614+
speech_handle._item_added([msg])
2615+
self._session._conversation_item_added(msg)
2616+
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text)
26552617

26562618
for tee in tees:
26572619
await tee.aclose()
2620+
speech_handle._mark_generation_done()
26582621

2659-
speech_handle._mark_generation_done() # mark the playout done before waiting for the tool execution # noqa: E501
2622+
if speech_handle.interrupted:
2623+
await utils.aio.cancel_and_wait(exe_task)
2624+
return
2625+
2626+
# wait for the tool execution to complete
26602627
tool_output.first_tool_started_fut.add_done_callback(
26612628
lambda _: self._session._update_agent_state("thinking")
26622629
)
@@ -2806,7 +2773,9 @@ def _on_false_interruption() -> None:
28062773
timeout, _on_false_interruption
28072774
)
28082775

2809-
async def _interrupt_paused_speech(self, old_task: asyncio.Task[None] | None = None) -> None:
2776+
async def _cancel_speech_pause(
2777+
self, old_task: asyncio.Task[None] | None = None, *, interrupt: bool = True
2778+
) -> None:
28102779
if old_task is not None:
28112780
await old_task
28122781

@@ -2817,8 +2786,14 @@ async def _interrupt_paused_speech(self, old_task: asyncio.Task[None] | None = N
28172786
if not self._paused_speech:
28182787
return
28192788

2820-
if not self._paused_speech.interrupted and self._paused_speech.allow_interruptions:
2821-
await self._paused_speech.interrupt() # ensure the speech is done
2789+
if (
2790+
interrupt
2791+
and not self._paused_speech.interrupted
2792+
and self._paused_speech.allow_interruptions
2793+
):
2794+
self._paused_speech.interrupt()
2795+
# ensure the generation is done
2796+
await self._paused_speech._wait_for_generation()
28222797
self._paused_speech = None
28232798

28242799
if self._session.options.resume_false_interruption and self._session.output.audio:

livekit-agents/livekit/agents/voice/generation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,10 @@ async def _traceable_fnc_tool(
627627
# TODO(theomonnom): Add the agent handoff inside the current_span
628628
_tool_completed(output)
629629

630-
task = asyncio.create_task(_traceable_fnc_tool(function_callable, fnc_call))
630+
task = asyncio.create_task(
631+
_traceable_fnc_tool(function_callable, fnc_call),
632+
name=f"func_exec_{fnc_call.name}", # task name is used for logging when the task is cancelled
633+
)
631634
_set_activity_task_info(
632635
task, speech_handle=speech_handle, function_call=fnc_call, inline_task=True
633636
)

0 commit comments

Comments
 (0)