Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions livekit-agents/livekit/agents/voice/avatar/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
self._lock = asyncio.Lock()
self._audio_publication: rtc.LocalTrackPublication | None = None
self._video_publication: rtc.LocalTrackPublication | None = None
self._republish_atask: asyncio.Task[None] | None = None
self._lazy_publish = _lazy_publish

# Audio/video sources
Expand Down Expand Up @@ -82,7 +81,6 @@ async def start(self) -> None:
await self._audio_recv.start()
self._audio_recv.on("clear_buffer", self._on_clear_buffer)

self._room.on("reconnected", self._on_reconnected)
self._room.on("connection_state_changed", self._on_connection_state_changed)
if self._room.isconnected():
self._room_connected_fut.set_result(None)
Expand Down Expand Up @@ -185,20 +183,11 @@ async def _handle_clear_buffer(audio_playing: bool) -> None:
task.add_done_callback(self._tasks.discard)
self._audio_playing = False

def _on_reconnected(self) -> None:
if self._lazy_publish and not self._video_publication:
return

if self._republish_atask:
self._republish_atask.cancel()
self._republish_atask = asyncio.create_task(self._publish_track())

def _on_connection_state_changed(self, _: rtc.ConnectionState) -> None:
if self._room.isconnected() and not self._room_connected_fut.done():
self._room_connected_fut.set_result(None)

async def aclose(self) -> None:
self._room.off("reconnected", self._on_reconnected)
self._room.off("connection_state_changed", self._on_connection_state_changed)

await self._audio_recv.aclose()
Expand All @@ -208,9 +197,6 @@ async def aclose(self) -> None:
await aio.cancel_and_wait(self._read_audio_atask)
await aio.cancel_and_wait(*self._tasks)

if self._republish_atask:
await aio.cancel_and_wait(self._republish_atask)

await self._av_sync.aclose()
await self._audio_source.aclose()
await self._video_source.aclose()
37 changes: 14 additions & 23 deletions livekit-agents/livekit/agents/voice/background_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class AudioConfig(NamedTuple):
# AudioSource.clear_queue() would abruptly cut off ambient sounds.
# Instead, we remove the sound from the mixer, and it will get removed 400ms later.
_AUDIO_SOURCE_BUFFER_MS = 400
_TRACK_NAME = "background_audio"


class BackgroundAudioPlayer:
Expand Down Expand Up @@ -109,7 +110,6 @@ def __init__(
self.publication: rtc.LocalTrackPublication | None = None
self._lock = asyncio.Lock()

self._republish_task: asyncio.Task[None] | None = None # republish the task on reconnect
self._mixer_atask: asyncio.Task[None] | None = None

self._play_tasks: list[asyncio.Task[None]] = []
Expand Down Expand Up @@ -262,7 +262,6 @@ async def start(
await self._publish_track()

self._mixer_atask = asyncio.create_task(self._run_mixer_task())
self._room.on("reconnected", self._on_reconnected)

if self._agent_session:
self._agent_session.on("agent_state_changed", self._agent_state_changed)
Expand All @@ -288,9 +287,6 @@ async def aclose(self) -> None:

await cancel_and_wait(*self._play_tasks)

if self._republish_task:
await cancel_and_wait(self._republish_task)

await cancel_and_wait(self._mixer_atask)
self._mixer_atask = None

Expand All @@ -300,18 +296,19 @@ async def aclose(self) -> None:
if self._agent_session:
self._agent_session.off("agent_state_changed", self._agent_state_changed)

self._room.off("reconnected", self._on_reconnected)

with contextlib.suppress(Exception):
if self.publication is not None:
await self._room.local_participant.unpublish_track(self.publication.sid)

def _on_reconnected(self) -> None:
if self._republish_task:
self._republish_task.cancel()

self.publication = None
self._republish_task = asyncio.create_task(self._republish_track_task())
# The cached publication SID may be stale if the SDK
# republished it during a full reconnect; resolve the current
# publication by track name before unpublishing.
current = self._find_publication_by_name(_TRACK_NAME)
if current is not None:
await self._room.local_participant.unpublish_track(current.sid)

def _find_publication_by_name(self, name: str) -> rtc.LocalTrackPublication | None:
for pub in self._room.local_participant.track_publications.values():
if pub.name == name:
return pub
return None

def _agent_state_changed(self, ev: AgentStateChangedEvent) -> None:
if not self._thinking_sound:
Expand Down Expand Up @@ -390,17 +387,11 @@ async def _publish_track(self) -> None:
if self.publication is not None:
return

track = rtc.LocalAudioTrack.create_audio_track("background_audio", self._audio_source)
track = rtc.LocalAudioTrack.create_audio_track(_TRACK_NAME, self._audio_source)
self.publication = await self._room.local_participant.publish_track(
track, self._track_publish_options or rtc.TrackPublishOptions()
)

@log_exceptions(logger=logger)
async def _republish_track_task(self) -> None:
# used to republish the track on agent reconnect
async with self._lock:
await self._publish_track()


class PlayHandle:
def __init__(self) -> None:
Expand Down
11 changes: 0 additions & 11 deletions livekit-agents/livekit/agents/voice/room_io/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def __init__(
sample_rate, num_channels, samples_per_channel=sample_rate // 20, progressive=True
)

# used to republish track on reconnection
self._republish_task: asyncio.Task[None] | None = None
self._flush_task: asyncio.Task[None] | None = None
self._interrupted_event = asyncio.Event()
self._forwarding_task: asyncio.Task[None] | None = None
Expand Down Expand Up @@ -81,12 +79,8 @@ def subscribed(self) -> asyncio.Future[None]:
async def start(self) -> None:
self._forwarding_task = asyncio.create_task(self._forward_audio())
await self._publish_track()
self._room.on("reconnected", self._on_reconnected)

async def aclose(self) -> None:
self._room.off("reconnected", self._on_reconnected)
if self._republish_task:
await utils.aio.cancel_and_wait(self._republish_task)
if self._flush_task:
await utils.aio.cancel_and_wait(self._flush_task)
if self._forwarding_task:
Expand Down Expand Up @@ -198,11 +192,6 @@ async def _forward_audio(self) -> None:
self.on_playback_started(created_at=time.time())
await self._audio_source.capture_frame(frame)

def _on_reconnected(self) -> None:
if self._republish_task:
self._republish_task.cancel()
self._republish_task = asyncio.create_task(self._publish_track())


class _ParticipantLegacyTranscriptionOutput:
def __init__(
Expand Down
206 changes: 206 additions & 0 deletions tests/test_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from __future__ import annotations

import asyncio
import contextlib
import os
import uuid
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
Expand All @@ -18,8 +20,17 @@
wait_for_participant,
wait_for_track_publication,
)
from livekit.agents.voice.room_io._output import _ParticipantAudioOutput

from .lk_server import LK_API_KEY, LK_API_SECRET, LK_URL, livekit_server # noqa: F401
from .utils.audio_test import AudioEnergyMonitor, SineToneSource
from .utils.livekit_test import (
connect_room as _connect_e2e_room,
make_room_name as _e2e_room_name,
simulate_full_reconnect,
simulate_resume,
wait_for_event,
)

TIMEOUT = 5.0

Expand Down Expand Up @@ -208,3 +219,198 @@ async def test_agent_joins(self):
result = await asyncio.wait_for(task, timeout=TIMEOUT)
assert result.identity == "my-agent"
assert result.kind == rtc.ParticipantKind.PARTICIPANT_KIND_AGENT


# -- Reconnect E2E tests --
#
# exercise resume vs full-reconnect behavior end-to-end against a
# real LiveKit server.

_AGENT_IDENTITY = "agent"
_USER_IDENTITY = "user"
_TONE_NAME = "agent_tone"
# Threshold tuned to discriminate the steady sine wave from server-side
# silence/noise. The 440Hz tone at 0.5 amplitude lands well above 0.1.
_AUDIO_RMS_THRESHOLD = 0.05


@contextlib.asynccontextmanager
async def _agent_publishing_tone(agent_room: rtc.Room):
"""Publish a steady sine tone via the production `_ParticipantAudioOutput`
helper, so the test exercises both the SDK and the agents-framework layer."""
tone = SineToneSource(frequency=440.0, amplitude=0.5)
output = _ParticipantAudioOutput(
agent_room,
sample_rate=tone._sample_rate,
num_channels=1,
track_publish_options=rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE),
track_name=_TONE_NAME,
)
# Replace the helper's internal AudioSource with our SineToneSource so the
# published track carries the known signal. _ParticipantAudioOutput
# constructs its own AudioSource in __init__; swap it before .start() so
# the LocalAudioTrack is built against ours.
await output._audio_source.aclose()
output._audio_source = tone.source
await output.start()
await tone.start()
try:
yield output, tone
finally:
await tone.aclose()
await output.aclose()


def _agent_audio_publications(user_room: rtc.Room) -> list[rtc.RemoteTrackPublication]:
"""All audio publications the user sees from the agent."""
agent = user_room.remote_participants.get(_AGENT_IDENTITY)
if agent is None:
return []
return [
pub for pub in agent.track_publications.values() if pub.kind == rtc.TrackKind.KIND_AUDIO
]


async def _await_subscribed_audio(user_room: rtc.Room) -> rtc.RemoteAudioTrack:
pub = await asyncio.wait_for(
wait_for_track_publication(
user_room,
identity=_AGENT_IDENTITY,
kind=rtc.TrackKind.KIND_AUDIO,
wait_for_subscription=True,
),
timeout=15.0,
)
track = pub.track
assert isinstance(track, rtc.RemoteAudioTrack)
return track


async def _wait_back_to_connected(room: rtc.Room, *, timeout: float = 15.0) -> None:
await wait_for_event(
room,
"connection_state_changed",
timeout=timeout,
predicate=lambda state: state == rtc.ConnectionState.CONN_CONNECTED,
)


@pytest.mark.skipif(
not os.environ.get("LIVEKIT_URL"),
reason="LIVEKIT_URL not set; skipping reconnect E2E tests "
"(set LIVEKIT_URL/LIVEKIT_API_KEY/LIVEKIT_API_SECRET to enable)",
)
@pytest.mark.skipif(
not hasattr(rtc, "SimulateScenarioKind"),
reason="livekit-rtc lacks SimulateScenarioKind; run `make link-rtc-local`",
)
class TestReconnect:
async def test_resume_preserves_publication_and_audio(self):
"""Resume must NOT churn the publication set, must NOT fire
`reconnected`, and audio must keep flowing."""
room_name = _e2e_room_name("resume")

async with (
_connect_e2e_room(_USER_IDENTITY, room_name) as user_room,
_connect_e2e_room(_AGENT_IDENTITY, room_name, agent=True) as agent_room,
_agent_publishing_tone(agent_room) as (_output, _tone),
):
await asyncio.wait_for(
wait_for_participant(user_room, identity=_AGENT_IDENTITY),
timeout=10.0,
)
track = await _await_subscribed_audio(user_room)

async with AudioEnergyMonitor.watch(track) as mon:
await mon.wait_for_audio(min_rms=_AUDIO_RMS_THRESHOLD, timeout=5.0)

# Snapshot publication state before the disturbance.
publications_before = _agent_audio_publications(user_room)
assert len(publications_before) == 1, publications_before
sid_before = publications_before[0].sid

# `Reconnected` MUST NOT fire on resume — register a tripwire.
reconnected_fired = asyncio.Event()
agent_room.on("reconnected", lambda: reconnected_fired.set())

await simulate_resume(agent_room)

# Engine cycles through Reconnecting -> Connected; wait for
# the second transition.
await _wait_back_to_connected(agent_room, timeout=15.0)

# Brief grace window for any stray Reconnected dispatch.
await asyncio.sleep(1.0)
assert not reconnected_fired.is_set(), (
"RoomEvent::Reconnected fired on a resume — should only fire on full reconnect"
)

# Publication identity is preserved.
publications_after = _agent_audio_publications(user_room)
assert len(publications_after) == 1, publications_after
assert publications_after[0].sid == sid_before, (
"publication SID changed on resume — engine should have preserved it"
)

# Audio continues to flow uninterrupted.
await mon.assert_audio_continuous(min_rms=_AUDIO_RMS_THRESHOLD, duration=1.5)

async def test_full_reconnect_republishes_once_and_audio_recovers(self):
"""Full reconnect must fire `reconnected` exactly once, end with
exactly one audio publication (the SDK's auto-republish — the
agents framework must not produce a duplicate), and audio must
recover."""
room_name = _e2e_room_name("full")

async with (
_connect_e2e_room(_USER_IDENTITY, room_name) as user_room,
_connect_e2e_room(_AGENT_IDENTITY, room_name, agent=True) as agent_room,
_agent_publishing_tone(agent_room) as (_output, _tone),
):
await asyncio.wait_for(
wait_for_participant(user_room, identity=_AGENT_IDENTITY),
timeout=10.0,
)
track = await _await_subscribed_audio(user_room)

async with AudioEnergyMonitor.watch(track) as mon:
await mon.wait_for_audio(min_rms=_AUDIO_RMS_THRESHOLD, timeout=5.0)

publications_before = _agent_audio_publications(user_room)
assert len(publications_before) == 1

# Bug regression check: should fire exactly once.
reconnect_count = 0

def _count(*_args) -> None:
nonlocal reconnect_count
reconnect_count += 1

agent_room.on("reconnected", _count)

await simulate_full_reconnect(agent_room)
await wait_for_event(agent_room, "reconnected", timeout=20.0)

# After the SDK's auto-republish the user observes
# unpublish -> publish on the agent. Resubscribe.
new_track = await _await_subscribed_audio(user_room)

# Bug regression: must be exactly ONE audio publication.
# Pre-fix, the agents framework's `_on_reconnected` raced
# to publish a second track on top of the SDK's
# auto-republished one.
await asyncio.sleep(0.5) # let any stray duplicate settle
publications_after = _agent_audio_publications(user_room)
assert len(publications_after) == 1, (
f"expected exactly 1 audio publication after full reconnect, "
f"saw {len(publications_after)}: {[p.sid for p in publications_after]}"
)
assert reconnect_count == 1, (
f"reconnected fired {reconnect_count} times; expected exactly 1"
)

async with AudioEnergyMonitor.watch(new_track) as new_mon:
await new_mon.wait_for_audio(min_rms=_AUDIO_RMS_THRESHOLD, timeout=10.0)
await new_mon.assert_audio_continuous(
min_rms=_AUDIO_RMS_THRESHOLD, duration=1.5
)
Loading
Loading