Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 77 additions & 46 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
remove_instructions,
update_instructions,
)
from .io import PlaybackFinishedEvent
from .speech_handle import DEFAULT_INPUT_DETAILS, InputDetails, SpeechHandle
from .turn import EndpointingOptions, TurnDetectionMode

Expand Down Expand Up @@ -3012,6 +3013,8 @@ def _on_first_frame(
2. _TextOutput.first_text_fut (None)
"""
nonlocal started_speaking_at, started_forwarding_at
if started_speaking_at is not None:
return
try:
started_speaking_at = fut.result() or time.time()
started_forwarding_at = (
Expand Down Expand Up @@ -3045,15 +3048,10 @@ async def _read_messages(
nonlocal read_transcript_from_tts
assert isinstance(self.llm, llm.RealtimeModel)

forward_tasks: list[asyncio.Task[Any]] = []
msg_tasks: list[asyncio.Task[Any]] = []
try:
async for msg in generation_ev.message_stream:
if len(forward_tasks) > 0:
logger.warning(
"expected to receive only one message generation from the realtime API"
)
break

msg_tasks = []
msg_modalities = await msg.modalities
tts_text_input: AsyncIterable[str] | None = None
if "audio" not in msg_modalities and self.tts:
Expand Down Expand Up @@ -3093,7 +3091,7 @@ async def _read_messages(
tr_text_input = timed_texts
read_transcript_from_tts = True

tasks.append(tts_task)
msg_tasks.append(tts_task)
realtime_audio_result = tts_gen_data.audio_ch
elif "audio" in msg_modalities:
realtime_audio = self._agent.realtime_audio_output_node(
Expand All @@ -3120,7 +3118,7 @@ async def _read_messages(
audio_output=audio_output,
tts_output=realtime_audio_result,
)
forward_tasks.append(forward_task)
msg_tasks.append(forward_task)
audio_out.first_frame_fut.add_done_callback(
partial(_on_first_frame, audio_out=audio_out)
)
Expand All @@ -3134,16 +3132,22 @@ async def _read_messages(
text_output=text_output,
source=tr_node_result,
)
forward_tasks.append(forward_task)
msg_tasks.append(forward_task)

if not audio_out and text_out:
text_out.first_text_fut.add_done_callback(_on_first_frame)

outputs.append((msg, text_out, audio_out))

await asyncio.gather(*forward_tasks)
await asyncio.gather(*msg_tasks)
msg_tasks.clear()

if audio_output is not None and audio_out is not None:
await audio_output.wait_for_playout()
if not audio_out.first_frame_fut.done():
audio_out.first_frame_fut.cancel()
finally:
await utils.aio.cancel_and_wait(*forward_tasks)
await utils.aio.cancel_and_wait(*msg_tasks)

message_outputs: list[
tuple[MessageGeneration, _TextOutput | None, _AudioOutput | None]
Expand Down Expand Up @@ -3241,59 +3245,86 @@ def _create_assistant_message(
msg.metrics = assistant_metrics
return msg

msg_gen, text_out, audio_out = (
message_outputs[0] if len(message_outputs) > 0 else (None, None, None)
) # there should be only one message
partial_idx: int | None = None
playback_ev: PlaybackFinishedEvent | None = None

forwarded_text = text_out.text if text_out else ""
if speech_handle.interrupted:
await utils.aio.cancel_and_wait(*tasks)

if msg_gen and audio_output is not None:
if message_outputs and audio_output is not None:
audio_output.clear_buffer()

playback_ev = await audio_output.wait_for_playout()
playback_position = playback_ev.playback_position
if (
audio_out is not None
and audio_out.first_frame_fut.done()
and not audio_out.first_frame_fut.cancelled()
):
# playback_ev is valid only if the first frame was already played
if playback_ev.synchronized_transcript is not None:
forwarded_text = playback_ev.synchronized_transcript
else:
forwarded_text = ""
playback_position = 0

# truncate server-side message (if supported)
if self.llm.capabilities.message_truncation:
msg_modalities = await msg_gen.modalities
self._rt_session.truncate(
message_id=msg_gen.message_id,
modalities=msg_modalities,
audio_end_ms=int(playback_position * 1000),
audio_transcript=forwarded_text,
)

elif read_transcript_from_tts and text_out and not text_out.text:
if playback_ev.interrupted:
for idx, (_, _, audio_out_i) in enumerate(message_outputs):
if (
audio_out_i is not None
and audio_out_i.first_frame_fut.done()
and not audio_out_i.first_frame_fut.cancelled()
):
partial_idx = idx

if partial_idx is not None and self.llm.capabilities.message_truncation:
msg_gen, text_out, _ = message_outputs[partial_idx]
transcript = playback_ev.synchronized_transcript or (
text_out.text if text_out else ""
)
msg_modalities = await msg_gen.modalities
self._rt_session.truncate(
message_id=msg_gen.message_id,
modalities=msg_modalities,
audio_end_ms=int(playback_ev.playback_position * 1000),
audio_transcript=transcript,
)
elif read_transcript_from_tts and any(
text_out is not None and not text_out.text for _, text_out, _ in message_outputs
):
logger.warning(
"`use_tts_aligned_transcript` is enabled but no agent transcript was returned from tts"
)

if msg_gen and forwarded_text:
trace_text_parts: list[str] = []
for idx, (msg_gen, text_out, audio_out) in enumerate(message_outputs):
if msg_gen is None:
continue

if audio_output is not None:
started_playing = (
audio_out is not None
and audio_out.first_frame_fut.done()
and not audio_out.first_frame_fut.cancelled()
)
if not started_playing:
continue

if idx == partial_idx and playback_ev is not None:
forwarded_text = playback_ev.synchronized_transcript or (
text_out.text if text_out else ""
)
item_interrupted = True
else:
forwarded_text = text_out.text if text_out else ""
item_interrupted = False

if not forwarded_text:
continue

trace_text_parts.append(forwarded_text)
msg = _create_assistant_message(
message_id=msg_gen.message_id,
forwarded_text=forwarded_text,
interrupted=speech_handle.interrupted,
interrupted=item_interrupted,
)
self._agent._chat_ctx._upsert_item(msg)
speech_handle._item_added([msg])
self._session._conversation_item_added(msg)
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text)

if audio_out is not None and not audio_out.first_frame_fut.done():
audio_out.first_frame_fut.cancel()
if trace_text_parts:
current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, " ".join(trace_text_parts))

for _, _, audio_out in message_outputs:
if audio_out is not None and not audio_out.first_frame_fut.done():
audio_out.first_frame_fut.cancel()

for tee in tees:
await tee.aclose()
Expand Down
15 changes: 9 additions & 6 deletions livekit-agents/livekit/agents/voice/room_io/_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import json
import time

Expand Down Expand Up @@ -100,7 +101,6 @@ async def capture_frame(self, frame: rtc.AudioFrame) -> None:
await super().capture_frame(frame)

if self._flush_task and not self._flush_task.done():
logger.error("capture_frame called while flush is in progress")
await self._flush_task

for f in self._audio_bstream.push(frame.data):
Expand All @@ -117,12 +117,15 @@ def flush(self) -> None:
if not self._pushed_duration:
return

if self._flush_task and not self._flush_task.done():
# shouldn't happen if only one active speech handle at a time
logger.error("flush called while playback is in progress")
self._flush_task.cancel()
prior_task = self._flush_task

async def _serialized_playout() -> None:
if prior_task is not None and not prior_task.done():
with contextlib.suppress(asyncio.CancelledError, Exception):
await prior_task
await self._wait_for_playout()

self._flush_task = asyncio.create_task(self._wait_for_playout())
self._flush_task = asyncio.create_task(_serialized_playout())

def clear_buffer(self) -> None:
self._audio_bstream.clear()
Expand Down
Loading
Loading