diff --git a/interceptor/src/lib.rs b/interceptor/src/lib.rs index 98008d405..e843ed656 100644 --- a/interceptor/src/lib.rs +++ b/interceptor/src/lib.rs @@ -24,6 +24,9 @@ pub mod twcc; pub use error::Error; +/// Attribute indicating the stream is probing incoming packets. +pub const ATTR_READ_PROBE: usize = 2295978936; + /// Attributes are a generic key/value store used by interceptors pub type Attributes = HashMap; diff --git a/interceptor/src/report/receiver/mod.rs b/interceptor/src/report/receiver/mod.rs index ef3381949..b5d8849f4 100644 --- a/interceptor/src/report/receiver/mod.rs +++ b/interceptor/src/report/receiver/mod.rs @@ -110,11 +110,11 @@ impl ReceiverReport { m.values().cloned().collect() }; for stream in streams { - let pkt = stream.generate_report(now); - - let a = Attributes::new(); - if let Err(err) = rtcp_writer.write(&[Box::new(pkt)], &a).await{ - log::warn!("failed sending: {}", err); + if let Some(pkt) = stream.generate_report(now) { + let a = Attributes::new(); + if let Err(err) = rtcp_writer.write(&[Box::new(pkt)], &a).await{ + log::warn!("failed sending: {}", err); + } } } } @@ -186,11 +186,17 @@ impl Interceptor for ReceiverReport { info: &StreamInfo, reader: Arc, ) -> Arc { + let wait_for_probe = info + .attributes + .get(&crate::ATTR_READ_PROBE) + .is_some_and(|v| *v != 0); + let stream = Arc::new(ReceiverStream::new( info.ssrc, info.clock_rate, reader, self.internal.now.clone(), + wait_for_probe, )); { let mut streams = self.internal.streams.lock().await; diff --git a/interceptor/src/report/receiver/receiver_stream.rs b/interceptor/src/report/receiver/receiver_stream.rs index d170922e8..652102123 100644 --- a/interceptor/src/report/receiver/receiver_stream.rs +++ b/interceptor/src/report/receiver/receiver_stream.rs @@ -13,6 +13,7 @@ struct ReceiverStreamInternal { packets: Vec, started: bool, + wait_for_probe: bool, seq_num_cycles: u16, last_seq_num: i32, last_report_seq_num: i32, @@ -40,7 +41,7 @@ impl ReceiverStreamInternal { (self.packets[pos / 64] & (1 << (pos % 64))) != 0 } - fn process_rtp(&mut self, now: SystemTime, pkt: &rtp::packet::Packet) { + fn process_rtp(&mut self, now: SystemTime, pkt: &rtp::packet::Packet, is_probe: bool) { if !self.started { // first frame self.started = true; @@ -79,6 +80,7 @@ impl ReceiverStreamInternal { self.last_rtp_time_rtp = pkt.header.timestamp; self.last_rtp_time_time = now; + self.wait_for_probe &= is_probe; } fn process_sender_report(&mut self, now: SystemTime, sr: &rtcp::sender_report::SenderReport) { @@ -158,6 +160,7 @@ impl ReceiverStream { clock_rate: u32, reader: Arc, now: Option, + wait_for_probe: bool, ) -> Self { let receiver_ssrc = rand::random::(); ReceiverStream { @@ -171,6 +174,7 @@ impl ReceiverStream { packets: vec![0u64; 128], started: false, + wait_for_probe, seq_num_cycles: 0, last_seq_num: 0, last_report_seq_num: 0, @@ -184,9 +188,9 @@ impl ReceiverStream { } } - pub(crate) fn process_rtp(&self, now: SystemTime, pkt: &rtp::packet::Packet) { + pub(crate) fn process_rtp(&self, now: SystemTime, pkt: &rtp::packet::Packet, is_probe: bool) { let mut internal = self.internal.lock(); - internal.process_rtp(now, pkt); + internal.process_rtp(now, pkt, is_probe); } pub(crate) fn process_sender_report( @@ -198,9 +202,17 @@ impl ReceiverStream { internal.process_sender_report(now, sr); } - pub(crate) fn generate_report(&self, now: SystemTime) -> rtcp::receiver_report::ReceiverReport { + pub(crate) fn generate_report( + &self, + now: SystemTime, + ) -> Option { let mut internal = self.internal.lock(); - internal.generate_report(now) + + if internal.wait_for_probe { + return None; + } + + Some(internal.generate_report(now)) } } @@ -213,6 +225,8 @@ impl RTPReader for ReceiverStream { buf: &mut [u8], a: &Attributes, ) -> Result<(rtp::packet::Packet, Attributes)> { + let is_probe = a.get(&crate::ATTR_READ_PROBE).is_some_and(|v| *v != 0); + let (pkt, attr) = self.parent_rtp_reader.read(buf, a).await?; let now = if let Some(f) = &self.now { @@ -220,7 +234,7 @@ impl RTPReader for ReceiverStream { } else { SystemTime::now() }; - self.process_rtp(now, &pkt); + self.process_rtp(now, &pkt, is_probe); Ok((pkt, attr)) } diff --git a/interceptor/src/report/receiver/receiver_test.rs b/interceptor/src/report/receiver/receiver_test.rs index 77fa4a929..ef259501d 100644 --- a/interceptor/src/report/receiver/receiver_test.rs +++ b/interceptor/src/report/receiver/receiver_test.rs @@ -58,6 +58,73 @@ async fn test_receiver_interceptor_before_any_packet() -> Result<()> { Ok(()) } +#[tokio::test(start_paused = true)] +async fn test_receiver_interceptor_read_probe() -> Result<()> { + let mt = Arc::new(MockTime::default()); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; + + let icpr: Arc = ReceiverReport::builder() + .with_interval(Duration::from_millis(50)) + .with_now_fn(time_gen) + .build("")?; + + let stream = MockStream::new( + &StreamInfo { + ssrc: 123456, + clock_rate: 90000, + attributes: [(crate::ATTR_READ_PROBE, 1)].into_iter().collect(), + ..Default::default() + }, + icpr, + ) + .await; + + // no report initially + tokio::time::timeout(Duration::from_millis(60), stream.written_rtcp()) + .await + .expect_err("expected no report"); + + stream + .receive_rtp(rtp::packet::Packet { + header: rtp::header::Header { + sequence_number: 7, + ..Default::default() + }, + ..Default::default() + }) + .await; + + let pkts = stream.written_rtcp().await.unwrap(); + assert_eq!(pkts.len(), 1); + if let Some(rr) = pkts[0] + .as_any() + .downcast_ref::() + { + assert_eq!(rr.reports.len(), 1); + assert_eq!( + rr.reports[0], + rtcp::reception_report::ReceptionReport { + ssrc: 123456, + last_sequence_number: 7, + last_sender_report: 0, + fraction_lost: 0, + total_lost: 0, + delay: 0, + jitter: 0, + } + ) + } else { + panic!(); + } + + stream.close().await?; + + Ok(()) +} + #[tokio::test] async fn test_receiver_interceptor_after_rtp_packets() -> Result<()> { let mt = Arc::new(MockTime::default()); diff --git a/webrtc/src/peer_connection/peer_connection_internal.rs b/webrtc/src/peer_connection/peer_connection_internal.rs index 2d8c1d9d6..6f69a0e32 100644 --- a/webrtc/src/peer_connection/peer_connection_internal.rs +++ b/webrtc/src/peer_connection/peer_connection_internal.rs @@ -1072,7 +1072,7 @@ impl PeerConnectionInternal { None => return Err(Error::ErrInterceptorNotBind), }; - let stream_info = create_stream_info( + let mut stream_info = create_stream_info( "".to_owned(), ssrc, params.codecs[0].payload_type, @@ -1080,15 +1080,22 @@ impl PeerConnectionInternal { ¶ms.header_extensions, None, ); + + // indicate this stream starts with probing + stream_info + .attributes + .insert(interceptor::ATTR_READ_PROBE, 1); + let (rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = self .dtls_transport .streams_for_ssrc(ssrc, &stream_info, &icpr) .await?; - let a = Attributes::new(); for _ in 0..=SIMULCAST_PROBE_COUNT { if mid.is_empty() || (rid.is_empty() && rsid.is_empty()) { - let (pkt, _) = rtp_interceptor.read(&mut buf, &a).await?; + let (pkt, a) = rtp_interceptor + .read(&mut buf, &stream_info.attributes) + .await?; let (m, r, rs, _) = handle_unknown_rtp_packet( &pkt.header, mid_extension_id as u8,