diff --git a/Cargo.toml b/Cargo.toml index 740f255a..b7187b6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ name = "bench_ua" path = "src/bin/bench_ua.rs" [dependencies] +arc-swap = "1.7.1" async-trait = "0.1.89" futures = "0.3.31" rsip = { version = "0.4.0" } diff --git a/src/dialog/dialog_layer.rs b/src/dialog/dialog_layer.rs index cb774528..f51065ab 100644 --- a/src/dialog/dialog_layer.rs +++ b/src/dialog/dialog_layer.rs @@ -559,7 +559,7 @@ impl DialogLayer { let addr = self .endpoint .transport_layer - .get_addrs() + .get_contact_addrs() .first() .ok_or(crate::Error::EndpointError("not sipaddrs".to_string()))? .clone(); diff --git a/src/dialog/tests/test_dialog_states.rs b/src/dialog/tests/test_dialog_states.rs index 2cf87c86..402e827d 100644 --- a/src/dialog/tests/test_dialog_states.rs +++ b/src/dialog/tests/test_dialog_states.rs @@ -70,6 +70,8 @@ pub async fn create_test_endpoint() -> crate::Result Vec { + self.transport_layer.get_contact_addrs() + } + pub fn get_record_route(&self) -> Result { let first_addr = self .transport_layer @@ -758,4 +762,8 @@ impl Endpoint { pub fn get_addrs(&self) -> Vec { self.inner.transport_layer.get_addrs() } + + pub fn get_contact_addrs(&self) -> Vec { + self.inner.transport_layer.get_contact_addrs() + } } diff --git a/src/transport/connection.rs b/src/transport/connection.rs index 1d84ce2b..c6cb9d10 100644 --- a/src/transport/connection.rs +++ b/src/transport/connection.rs @@ -197,6 +197,13 @@ impl SipConnection { SipConnection::WebSocketListener(transport) => transport.get_addr(), } } + pub fn get_contact_addr(&self) -> SipAddr { + if let SipConnection::Udp(transpport) = self { + transpport.get_contact_addr() + } else { + self.get_addr().to_owned() + } + } pub async fn send(&self, msg: rsip::SipMessage, destination: Option<&SipAddr>) -> Result<()> { match self { SipConnection::Channel(transport) => transport.send(msg).await, diff --git a/src/transport/tests/test_udp.rs b/src/transport/tests/test_udp.rs index 039b3264..f0f3d669 100644 --- a/src/transport/tests/test_udp.rs +++ b/src/transport/tests/test_udp.rs @@ -1,8 +1,8 @@ use crate::{ transport::{ connection::{KEEPALIVE_REQUEST, KEEPALIVE_RESPONSE}, - udp::UdpConnection, - TransportEvent, + udp::{UdpConnection, UdpInner}, + SipAddr, TransportEvent, }, Result, }; @@ -85,3 +85,162 @@ async fn test_udp_recv_sip_message() -> Result<()> { }; Ok(()) } + +#[tokio::test] +async fn test_udp_learns_public_addr_from_response_when_external_not_configured() -> Result<()> { + let peer = UdpConnection::create_connection_with_auto_learn_public_addr( + "127.0.0.1:0".parse()?, + None, + None, + true, + ) + .await?; + let remote = UdpConnection::create_connection("127.0.0.1:0".parse()?, None, None).await?; + let (tx, mut rx) = unbounded_channel(); + + let remote_public_port = 62000u16; + let response = format!( + "SIP/2.0 100 Trying\r\n\ +Via: SIP/2.0/UDP 10.0.0.10:5060;branch=z9hG4bK1;rport={};received=198.51.100.10\r\n\ +From: ;tag=1\r\n\ +To: \r\n\ +Call-ID: abc\r\n\ +CSeq: 1 INVITE\r\n\ +Content-Length: 0\r\n\r\n", + remote_public_port + ); + + let peer_addr = peer.get_addr().to_owned(); + tokio::spawn(async move { + sleep(Duration::from_millis(20)).await; + remote + .send_raw(response.as_bytes(), &peer_addr) + .await + .expect("send_raw"); + }); + + select! { + _ = peer.serve_loop(tx) => { + assert!(false, "peer serve_loop exited"); + } + event = rx.recv() => { + match event { + Some(TransportEvent::Incoming(msg, _, _)) => { + assert!(msg.is_response()); + assert_eq!( + peer.get_contact_addr().to_string(), + format!("UDP 198.51.100.10:{}", remote_public_port) + ); + } + _ => assert!(false, "unexpected event"), + } + } + _ = sleep(Duration::from_millis(500)) => { + assert!(false, "timeout waiting"); + } + }; + + Ok(()) +} + +#[tokio::test] +async fn test_udp_contact_prefers_configured_external_addr() -> Result<()> { + let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await?; + let local_addr = socket.local_addr()?; + let peer = UdpConnection::attach_with_auto_learn_public_addr( + UdpInner { + conn: socket, + addr: SipAddr::from(local_addr), + learned_public_addr: arc_swap::ArcSwapOption::empty(), + auto_learn_public_addr: false, + }, + Some("203.0.113.10:5060".parse()?), + None, + true, + ) + .await; + let remote = UdpConnection::create_connection("127.0.0.1:0".parse()?, None, None).await?; + let (tx, mut rx) = unbounded_channel(); + + let response = "SIP/2.0 100 Trying\r\n\ +Via: SIP/2.0/UDP 10.0.0.10:5060;branch=z9hG4bK1;rport=62000;received=198.51.100.10\r\n\ +From: ;tag=1\r\n\ +To: \r\n\ +Call-ID: abc\r\n\ +CSeq: 1 INVITE\r\n\ +Content-Length: 0\r\n\r\n"; + + let peer_local_addr = SipAddr::from(local_addr); + tokio::spawn(async move { + sleep(Duration::from_millis(20)).await; + remote + .send_raw(response.as_bytes(), &peer_local_addr) + .await + .expect("send_raw"); + }); + + select! { + _ = peer.serve_loop(tx) => { + assert!(false, "peer serve_loop exited"); + } + event = rx.recv() => { + match event { + Some(TransportEvent::Incoming(msg, _, _)) => { + assert!(msg.is_response()); + assert_eq!(peer.get_contact_addr().to_string(), "UDP 203.0.113.10:5060"); + } + _ => assert!(false, "unexpected event"), + } + } + _ = sleep(Duration::from_millis(500)) => { + assert!(false, "timeout waiting"); + } + }; + + Ok(()) +} + +#[tokio::test] +async fn test_udp_does_not_learn_public_addr_by_default() -> Result<()> { + let peer = UdpConnection::create_connection("127.0.0.1:0".parse()?, None, None).await?; + let remote = UdpConnection::create_connection("127.0.0.1:0".parse()?, None, None).await?; + let (tx, mut rx) = unbounded_channel(); + + let local_contact_before = peer.get_contact_addr(); + let response = "SIP/2.0 100 Trying\r\n\ +Via: SIP/2.0/UDP 10.0.0.10:5060;branch=z9hG4bK1;rport=62000;received=198.51.100.10\r\n\ +From: ;tag=1\r\n\ +To: \r\n\ +Call-ID: abc\r\n\ +CSeq: 1 INVITE\r\n\ +Content-Length: 0\r\n\r\n"; + + let peer_addr = peer.get_addr().to_owned(); + tokio::spawn(async move { + sleep(Duration::from_millis(20)).await; + remote + .send_raw(response.as_bytes(), &peer_addr) + .await + .expect("send_raw"); + }); + + select! { + _ = peer.serve_loop(tx) => { + assert!(false, "peer serve_loop exited"); + } + event = rx.recv() => { + match event { + Some(TransportEvent::Incoming(msg, _, _)) => { + assert!(msg.is_response()); + assert_eq!(peer.get_contact_addr(), local_contact_before); + } + _ => assert!(false, "unexpected event"), + } + } + _ = sleep(Duration::from_millis(500)) => { + assert!(false, "timeout waiting"); + } + }; + + Ok(()) +} diff --git a/src/transport/transport_layer.rs b/src/transport/transport_layer.rs index 859c49c4..6627d8dd 100644 --- a/src/transport/transport_layer.rs +++ b/src/transport/transport_layer.rs @@ -203,6 +203,16 @@ impl TransportLayer { } } + pub fn get_contact_addrs(&self) -> Vec { + match self.inner.listens.read() { + Ok(listens) => listens.iter().map(|t| t.get_contact_addr()).collect(), + Err(e) => { + warn!(error = ?e, "Failed to read listens"); + Vec::new() + } + } + } + /// Set an async whitelist callback invoked on incoming packets/connections. pub fn set_whitelist(&self, whitelist: T) where @@ -501,10 +511,15 @@ impl Drop for TransportLayer { mod tests { use crate::resolver::SipResolver; use crate::{ - transport::{udp::UdpConnection, SipAddr}, + transport::{ + udp::{UdpConnection, UdpInner}, + SipAddr, + }, Result, }; + use arc_swap::ArcSwapOption; use rsip::Transport; + use std::sync::Arc; #[tokio::test] async fn test_lookup() -> Result<()> { @@ -627,4 +642,44 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_contact_addrs_do_not_change_listener_addrs() -> Result<()> { + let tl = super::TransportLayer::new(tokio_util::sync::CancellationToken::new()); + let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await?; + let local_addr = socket.local_addr()?; + let local_sip_addr = SipAddr { + r#type: Some(rsip::transport::Transport::Udp), + addr: local_addr.into(), + }; + + let learned_public_addr = ArcSwapOption::empty(); + learned_public_addr.store(Some(Arc::new( + "198.51.100.10:62000".parse::()?, + ))); + + let udp_conn = UdpConnection::attach( + UdpInner { + conn: socket, + addr: local_sip_addr.clone(), + learned_public_addr, + auto_learn_public_addr: false, + }, + None, + Some(tl.inner.cancel_token.child_token()), + ) + .await; + + tl.add_transport(udp_conn.into()); + + let addrs = tl.get_addrs(); + assert_eq!(addrs.len(), 1); + assert_eq!(addrs[0], local_sip_addr); + + let contact_addrs = tl.get_contact_addrs(); + assert_eq!(contact_addrs.len(), 1); + assert_eq!(contact_addrs[0].to_string(), "UDP 198.51.100.10:62000"); + + Ok(()) + } } diff --git a/src/transport/udp.rs b/src/transport/udp.rs index d6dc7db5..a3f424ff 100644 --- a/src/transport/udp.rs +++ b/src/transport/udp.rs @@ -7,14 +7,19 @@ use crate::{ }, Result, }; +use arc_swap::ArcSwapOption; use bytes::BytesMut; +use rsip::prelude::HeadersExt; use std::{net::SocketAddr, sync::Arc}; use tokio::net::UdpSocket; use tokio_util::sync::CancellationToken; use tracing::{debug, warn}; + pub struct UdpInner { pub conn: UdpSocket, pub addr: SipAddr, + pub learned_public_addr: ArcSwapOption, + pub auto_learn_public_addr: bool, } #[derive(Clone)] @@ -30,6 +35,16 @@ impl UdpConnection { external: Option, cancel_token: Option, ) -> Self { + Self::attach_with_auto_learn_public_addr(inner, external, cancel_token, false).await + } + + pub async fn attach_with_auto_learn_public_addr( + mut inner: UdpInner, + external: Option, + cancel_token: Option, + auto_learn_public_addr: bool, + ) -> Self { + inner.auto_learn_public_addr = auto_learn_public_addr; UdpConnection { external: external.map(|addr| SipAddr { r#type: Some(rsip::transport::Transport::Udp), @@ -44,6 +59,16 @@ impl UdpConnection { local: SocketAddr, external: Option, cancel_token: Option, + ) -> Result { + Self::create_connection_with_auto_learn_public_addr(local, external, cancel_token, false) + .await + } + + pub async fn create_connection_with_auto_learn_public_addr( + local: SocketAddr, + external: Option, + cancel_token: Option, + auto_learn_public_addr: bool, ) -> Result { let conn = UdpSocket::bind(local).await?; @@ -57,7 +82,12 @@ impl UdpConnection { r#type: Some(rsip::transport::Transport::Udp), addr: addr.into(), }), - inner: Arc::new(UdpInner { addr, conn }), + inner: Arc::new(UdpInner { + addr, + conn, + learned_public_addr: ArcSwapOption::empty(), + auto_learn_public_addr, + }), cancel_token, }; debug!(local = %t, ?external, "created UDP connection"); @@ -164,6 +194,10 @@ impl UdpConnection { } }; + if self.external.is_none() && self.inner.auto_learn_public_addr { + self.learn_public_addr_from_message(&msg); + } + debug!(len, src=%addr, dest=%self.get_addr(), raw_message=undecoded, "udp received"); sender.send(TransportEvent::Incoming( @@ -229,9 +263,53 @@ impl UdpConnection { &self.inner.addr } } + + pub fn get_contact_addr(&self) -> SipAddr { + if let Some(external) = &self.external { + external.clone() + } else { + self.inner + .learned_public_addr + .load_full() + .map(|addr| SipAddr { + r#type: Some(rsip::transport::Transport::Udp), + addr: (*addr).into(), + }) + .unwrap_or_else(|| self.inner.addr.clone()) + } + } pub fn cancel_token(&self) -> Option { self.cancel_token.clone() } + + fn learn_public_addr_from_message(&self, msg: &rsip::SipMessage) { + let response = match msg { + rsip::SipMessage::Response(resp) => resp, + rsip::SipMessage::Request(_) => return, + }; + + let via = match response.via_header() { + Ok(via) => via, + Err(_) => return, + }; + + let target = match SipConnection::parse_target_from_via(via) { + Ok((transport, host_with_port)) if transport == rsip::transport::Transport::Udp => { + match host_with_port.try_into() { + Ok(addr) => addr, + Err(_) => return, + } + } + _ => return, + }; + + let current = self.inner.learned_public_addr.load(); + let changed = current.as_deref() != Some(&target); + if changed { + debug!(addr = %target, "udp learned public address"); + self.inner.learned_public_addr.store(Some(Arc::new(target))); + } + } } impl std::fmt::Display for UdpConnection {