diff --git a/rs/hang/examples/video.rs b/rs/hang/examples/video.rs index 920435deb..eecabbedc 100644 --- a/rs/hang/examples/video.rs +++ b/rs/hang/examples/video.rs @@ -93,7 +93,7 @@ fn create_track(broadcast: &mut moq_lite::BroadcastProducer) -> hang::TrackProdu .produce(); // Publish the catalog track to the broadcast. - broadcast.insert_track(catalog.track.consume()); + broadcast.insert_track(catalog.track.clone()); // Actually create the media track now. let track = broadcast.create_track(video_track); diff --git a/rs/hang/src/model/broadcast.rs b/rs/hang/src/model/broadcast.rs index c2407587d..86838cea5 100644 --- a/rs/hang/src/model/broadcast.rs +++ b/rs/hang/src/model/broadcast.rs @@ -19,7 +19,7 @@ pub struct BroadcastProducer { impl BroadcastProducer { pub fn new(mut inner: moq_lite::BroadcastProducer) -> Self { let catalog = Catalog::default().produce(); - inner.insert_track(catalog.track.consume()); + inner.insert_track(catalog.track.clone()); Self { inner, diff --git a/rs/moq-lite/src/ietf/subscriber.rs b/rs/moq-lite/src/ietf/subscriber.rs index 3d729a647..5354aa169 100644 --- a/rs/moq-lite/src/ietf/subscriber.rs +++ b/rs/moq-lite/src/ietf/subscriber.rs @@ -466,7 +466,7 @@ impl Subscriber { // NOTE: This is debated in the IETF draft, but is significantly easier to implement. let mut broadcast = self.start_announce(msg.track_namespace.to_owned())?; - let exists = broadcast.insert_track(track.consume()); + let exists = broadcast.insert_track(track); if exists { tracing::warn!(track = %msg.track_name, "track already exists, replacing it"); } diff --git a/rs/moq-lite/src/lite/publisher.rs b/rs/moq-lite/src/lite/publisher.rs index c90ef99f5..4f0c90c65 100644 --- a/rs/moq-lite/src/lite/publisher.rs +++ b/rs/moq-lite/src/lite/publisher.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use futures::{FutureExt, StreamExt, stream::FuturesUnordered}; use web_async::FuturesExt; use crate::{ @@ -215,49 +216,21 @@ impl Publisher { priority: PriorityQueue, version: Version, ) -> Result<(), Error> { - // TODO use a BTreeMap serve the latest N groups by sequence. - // Until then, we'll implement N=2 manually. - // Also, this is more complicated because we can't use tokio because of WASM. - // We need to drop futures in order to cancel them and keep polling them with select! - let mut old_group = None; - let mut new_group = None; - - // Annoying that we can't use a tuple here as we need the compiler to infer the type. - // Otherwise we'd have to pick Send or !Send... - let mut old_sequence = None; - let mut new_sequence = None; - - // Keep reading groups from the track, some of which may arrive out of order. + let mut tasks = FuturesUnordered::new(); + loop { let group = tokio::select! { - biased; + // Poll all active group futures; never matches but keeps them running. + true = async { + while tasks.next().await.is_some() {} + false + } => unreachable!(), Some(group) = track.next_group().transpose() => group, - Some(_) = async { Some(old_group.as_mut()?.await) } => { - old_group = None; - old_sequence = None; - continue; - }, - Some(_) = async { Some(new_group.as_mut()?.await) } => { - new_group = old_group; - new_sequence = old_sequence; - old_group = None; - old_sequence = None; - continue; - }, else => return Ok(()), }?; let sequence = group.info.sequence; - let latest = new_sequence.as_ref().unwrap_or(&0); - - tracing::debug!(subscribe = %subscribe.id, track = %track.info.name, sequence, latest, "serving group"); - - // If this group is older than the oldest group we're serving, skip it. - // We always serve at most two groups, but maybe we should serve only sequence >= MAX-1. - if sequence < *old_sequence.as_ref().unwrap_or(&0) { - tracing::debug!(subscribe = %subscribe.id, track = %track.info.name, old = %sequence, %latest, "skipping group"); - continue; - } + tracing::debug!(subscribe = %subscribe.id, track = %track.info.name, sequence, "serving group"); let msg = lite::Group { subscribe: subscribe.id, @@ -265,29 +238,7 @@ impl Publisher { }; let priority = priority.insert(track.info.priority, sequence); - - // Spawn a task to serve this group, ignoring any errors because they don't really matter. - // TODO add some logging at least. - let handle = Box::pin(Self::serve_group(session.clone(), msg, priority, group, version)); - - // Terminate the old group if it's still running. - if let Some(old_sequence) = old_sequence.take() { - tracing::debug!(subscribe = %subscribe.id, track = %track.info.name, old = %old_sequence, %latest, "aborting group"); - old_group.take(); // Drop the future to cancel it. - } - - assert!(old_group.is_none()); - - if sequence >= *latest { - old_group = new_group; - old_sequence = new_sequence; - - new_group = Some(handle); - new_sequence = Some(sequence); - } else { - old_group = Some(handle); - old_sequence = Some(sequence); - } + tasks.push(Self::serve_group(session.clone(), msg, priority, group, version).map(|_| ())); } } diff --git a/rs/moq-lite/src/model/broadcast.rs b/rs/moq-lite/src/model/broadcast.rs index 15affe522..d57a19031 100644 --- a/rs/moq-lite/src/model/broadcast.rs +++ b/rs/moq-lite/src/model/broadcast.rs @@ -16,11 +16,11 @@ use super::Track; struct State { // When explicitly publishing, we hold a reference to the consumer. // This prevents the track from being marked as "unused". - published: HashMap, + consumers: HashMap, // When requesting, we hold a reference to the producer for dynamic tracks. // The track will be marked as "unused" when the last consumer is dropped. - requested: HashMap, + producers: HashMap, } /// A collection of media tracks that can be published and subscribed to. @@ -58,8 +58,8 @@ impl BroadcastProducer { pub fn new() -> Self { Self { state: Lock::new(State { - published: HashMap::new(), - requested: HashMap::new(), + consumers: HashMap::new(), + producers: HashMap::new(), }), closed: Default::default(), requested: async_channel::unbounded(), @@ -74,24 +74,24 @@ impl BroadcastProducer { /// Produce a new track and insert it into the broadcast. pub fn create_track(&mut self, track: Track) -> TrackProducer { - let track = track.produce(); - self.insert_track(track.consume()); + let track = TrackProducer::new(track); + self.insert_track(track.clone()); track } /// Insert a track into the lookup, returning true if it was unique. - pub fn insert_track(&mut self, track: TrackConsumer) -> bool { + /// + /// NOTE: You probably want to [TrackProducer::clone] to keep publishing to the track. + pub fn insert_track(&mut self, track: TrackProducer) -> bool { let mut state = self.state.lock(); - let unique = state.published.insert(track.info.name.clone(), track.clone()).is_none(); - let removed = state.requested.remove(&track.info.name).is_some(); - - unique && !removed + state.consumers.insert(track.info.name.clone(), track.consume()); + state.producers.insert(track.info.name.clone(), track).is_none() } /// Remove a track from the lookup. pub fn remove_track(&mut self, name: &str) -> bool { let mut state = self.state.lock(); - state.published.remove(name).is_some() || state.requested.remove(name).is_some() + state.consumers.remove(name).is_some() || state.producers.remove(name).is_some() } pub fn consume(&self) -> BroadcastConsumer { @@ -150,8 +150,8 @@ impl Drop for BroadcastProducer { let mut state = self.state.lock(); // Cleanup any published tracks. - state.published.clear(); - state.requested.clear(); + state.consumers.clear(); + state.producers.clear(); } } @@ -192,13 +192,7 @@ impl BroadcastConsumer { pub fn subscribe_track(&self, track: &Track) -> TrackConsumer { let mut state = self.state.lock(); - // Return any explictly published track. - if let Some(consumer) = state.published.get(&track.name).cloned() { - return consumer; - } - - // Return any requested tracks. - if let Some(producer) = state.requested.get(&track.name) { + if let Some(producer) = state.producers.get(&track.name) { return producer.consume(); } @@ -219,13 +213,13 @@ impl BroadcastConsumer { } // Insert the producer into the lookup so we will deduplicate requests. - state.requested.insert(producer.info.name.clone(), producer.clone()); + state.producers.insert(producer.info.name.clone(), producer.clone()); // Remove the track from the lookup when it's unused. let state = self.state.clone(); web_async::spawn(async move { producer.unused().await; - state.lock().requested.remove(&producer.info.name); + state.lock().producers.remove(&producer.info.name); }); consumer @@ -268,19 +262,19 @@ mod test { let mut track1 = Track::new("track1").produce(); // Make sure we can insert before a consumer is created. - producer.insert_track(track1.consume()); + producer.insert_track(track1.clone()); track1.append_group(); let consumer = producer.consume(); - let mut track1_sub = consumer.subscribe_track(&track1.info); + let mut track1_sub = consumer.subscribe_track(&Track::new("track1")); track1_sub.assert_group(); let mut track2 = Track::new("track2").produce(); - producer.insert_track(track2.consume()); + producer.insert_track(track2.clone()); let consumer2 = producer.consume(); - let mut track2_consumer = consumer2.subscribe_track(&track2.info); + let mut track2_consumer = consumer2.subscribe_track(&Track::new("track2")); track2_consumer.assert_no_group(); track2.append_group(); @@ -330,9 +324,8 @@ mod test { consumer.assert_not_closed(); // Create a new track and insert it into the broadcast. - let mut track1 = Track::new("track1").produce(); + let mut track1 = producer.create_track(Track::new("track1")); track1.append_group(); - producer.insert_track(track1.consume()); let mut track1c = consumer.subscribe_track(&track1.info); let track2 = consumer.subscribe_track(&Track::new("track2")); @@ -405,10 +398,9 @@ mod test { #[tokio::test] async fn requested_unused() { let mut broadcast = Broadcast::produce(); - let consumer = broadcast.consume(); // Subscribe to a track that doesn't exist - this creates a request - let consumer1 = consumer.subscribe_track(&Track::new("unknown_track")); + let consumer1 = broadcast.consume().subscribe_track(&Track::new("unknown_track")); // Get the requested track producer let producer1 = broadcast.assert_request(); @@ -420,7 +412,7 @@ mod test { ); // Making a new consumer will keep the producer alive - let consumer2 = consumer.subscribe_track(&Track::new("unknown_track")); + let consumer2 = broadcast.consume().subscribe_track(&Track::new("unknown_track")); consumer2.assert_is_clone(&consumer1); // Drop the consumer subscription @@ -447,7 +439,7 @@ mod test { tokio::time::sleep(std::time::Duration::from_millis(1)).await; // Now the cleanup task should have run and we can subscribe again to the unknown track. - let consumer3 = consumer.subscribe_track(&Track::new("unknown_track")); + let consumer3 = broadcast.consume().subscribe_track(&Track::new("unknown_track")); let producer2 = broadcast.assert_request(); // Drop the consumer, now the producer should be unused diff --git a/rs/moq-lite/src/model/group.rs b/rs/moq-lite/src/model/group.rs index 43ab738a7..de82266c0 100644 --- a/rs/moq-lite/src/model/group.rs +++ b/rs/moq-lite/src/model/group.rs @@ -215,4 +215,24 @@ impl GroupConsumer { } } } + + pub async fn get_frame(&self, index: usize) -> Result> { + let mut state = self.state.clone(); + let Ok(state) = state + .wait_for(|state| index < state.frames.len() || state.closed.is_some()) + .await + else { + return Err(Error::Cancel); + }; + + if let Some(frame) = state.frames.get(index).cloned() { + return Ok(Some(frame)); + } + + match &state.closed { + Some(Ok(_)) => Ok(None), + Some(Err(err)) => Err(err.clone()), + _ => unreachable!(), + } + } } diff --git a/rs/moq-lite/src/model/track.rs b/rs/moq-lite/src/model/track.rs index 06f69508e..7f0d715e4 100644 --- a/rs/moq-lite/src/model/track.rs +++ b/rs/moq-lite/src/model/track.rs @@ -18,7 +18,9 @@ use crate::{Error, Result}; use super::{Group, GroupConsumer, GroupProducer}; -use std::{cmp::Ordering, future::Future}; +use std::{collections::VecDeque, future::Future}; + +const MAX_CACHE: std::time::Duration = std::time::Duration::from_secs(30); #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -42,8 +44,28 @@ impl Track { #[derive(Default)] struct TrackState { - latest: Option, + groups: VecDeque<(tokio::time::Instant, GroupConsumer)>, closed: Option>, + offset: usize, + + max_sequence: Option, + + // The largest sequence number that has been dropped. + drop_sequence: Option, +} + +impl TrackState { + fn trim(&mut self, now: tokio::time::Instant) { + while let Some((timestamp, _)) = self.groups.front() { + if now.saturating_duration_since(*timestamp) > MAX_CACHE { + let (_, group) = self.groups.pop_front().unwrap(); + self.drop_sequence = Some(self.drop_sequence.unwrap_or(0).max(group.info.sequence)); + self.offset += 1; + } else { + break; + } + } + } } /// A producer for a track, used to create new groups. @@ -65,16 +87,10 @@ impl TrackProducer { pub fn insert_group(&mut self, group: GroupConsumer) -> bool { self.state.send_if_modified(|state| { assert!(state.closed.is_none()); - - if let Some(latest) = &state.latest { - match group.info.cmp(&latest.info) { - Ordering::Less => return false, - Ordering::Equal => return false, - Ordering::Greater => (), - } - } - - state.latest = Some(group.clone()); + let now = tokio::time::Instant::now(); + state.trim(now); + state.groups.push_back((now, group.clone())); + state.max_sequence = Some(state.max_sequence.unwrap_or(0).max(group.info.sequence)); true }) } @@ -94,11 +110,15 @@ impl TrackProducer { self.state.send_if_modified(|state| { assert!(state.closed.is_none()); - let sequence = state.latest.as_ref().map_or(0, |group| group.info.sequence + 1); + let now = tokio::time::Instant::now(); + state.trim(now); + + let sequence = state.max_sequence.map_or(0, |sequence| sequence + 1); let group = Group { sequence }.produce(); - state.latest = Some(group.consume()); - producer = Some(group); + state.groups.push_back((now, group.consume())); + state.max_sequence = Some(sequence); + producer = Some(group); true }); @@ -122,10 +142,12 @@ impl TrackProducer { /// Create a new consumer for the track. pub fn consume(&self) -> TrackConsumer { + let state = self.state.borrow(); TrackConsumer { info: self.info.clone(), state: self.state.subscribe(), - prev: None, + // Start at the latest group + index: state.offset + state.groups.len().saturating_sub(1), } } @@ -154,7 +176,7 @@ impl From for TrackProducer { pub struct TrackConsumer { pub info: Track, state: watch::Receiver, - prev: Option, // The previous sequence number + index: usize, } impl TrackConsumer { @@ -166,24 +188,61 @@ impl TrackConsumer { let Ok(state) = self .state .wait_for(|state| { - state.latest.as_ref().map(|group| group.info.sequence) > self.prev || state.closed.is_some() + let index = self.index.saturating_sub(state.offset); + state.groups.get(index).is_some() || state.closed.is_some() }) .await else { return Err(Error::Cancel); }; + let index = self.index.saturating_sub(state.offset); + if let Some(group) = state.groups.get(index) { + self.index = state.offset + index + 1; + return Ok(Some(group.1.clone())); + } + match &state.closed { - Some(Ok(_)) => return Ok(None), - Some(Err(err)) => return Err(err.clone()), - _ => {} + Some(Ok(_)) => Ok(None), + Some(Err(err)) => Err(err.clone()), + _ => unreachable!(), } + } + + /// Block until the group is available. + /// + /// NOTE: This can block indefinitely if the requested group is dropped. + pub async fn get_group(&self, sequence: u64) -> Result> { + let mut state = self.state.clone(); - // If there's a new latest group, return it. - let group = state.latest.clone().unwrap(); - self.prev = Some(group.info.sequence); + let Ok(state) = state + .wait_for(|state| { + if state.closed.is_some() { + return true; + } - Ok(Some(group)) + if let Some(drop_sequence) = state.drop_sequence + && drop_sequence >= sequence + { + return true; + } + + state.groups.iter().any(|(_, group)| group.info.sequence == sequence) + }) + .await + else { + return Err(Error::Cancel); + }; + + if let Some((_, group)) = state.groups.iter().find(|(_, group)| group.info.sequence == sequence) { + return Ok(Some(group.clone())); + } + + match &state.closed { + Some(Ok(_)) => Ok(None), // end of stream + Some(Err(err)) => Err(err.clone()), + None => Ok(None), // Dropped + } } /// Block until the track is closed. diff --git a/rs/moq-relay/src/auth.rs b/rs/moq-relay/src/auth.rs index b011c97a8..cb4c7f701 100644 --- a/rs/moq-relay/src/auth.rs +++ b/rs/moq-relay/src/auth.rs @@ -68,7 +68,7 @@ pub struct AuthToken { pub cluster: bool, } -const REFRESH_ERROR_INTERVAL: Duration = Duration::from_mins(5); +const REFRESH_ERROR_INTERVAL: Duration = Duration::from_secs(300); #[derive(Clone)] pub struct Auth { diff --git a/rs/moq-relay/src/web.rs b/rs/moq-relay/src/web.rs index 620fa9bbb..81534f97f 100644 --- a/rs/moq-relay/src/web.rs +++ b/rs/moq-relay/src/web.rs @@ -22,18 +22,12 @@ use axum::{ use bytes::Bytes; use clap::Parser; use moq_lite::{OriginConsumer, OriginProducer}; -use serde::{Deserialize, Serialize}; use std::future::Future; use tower_http::cors::{Any, CorsLayer}; use crate::{Auth, Cluster}; -#[derive(Debug, Deserialize)] -struct Params { - jwt: Option, -} - -#[derive(Parser, Clone, Debug, Deserialize, Serialize, Default)] +#[derive(Parser, Clone, Debug, serde::Deserialize, serde::Serialize, Default)] #[serde(deny_unknown_fields, default)] pub struct WebConfig { #[command(flatten)] @@ -166,10 +160,70 @@ async fn serve_fingerprint(State(state): State>) -> String { .clone() } +#[derive(Debug, serde::Deserialize)] +struct AuthParams { + jwt: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct FetchParams { + #[serde(flatten)] + auth: AuthParams, + + #[serde(default)] + group: FetchGroup, + + #[serde(default)] + frame: FetchFrame, +} + +#[derive(Debug, Default)] +enum FetchGroup { + // Return the group at the given sequence number. + Num(u64), + + // Return the latest group. + #[default] + Latest, +} + +impl<'de> serde::Deserialize<'de> for FetchGroup { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + if let Ok(num) = s.parse::() { + Ok(FetchGroup::Num(num)) + } else if s == "latest" { + Ok(FetchGroup::Latest) + } else { + Err(serde::de::Error::custom(format!("invalid group value: {s}"))) + } + } +} + +#[derive(Debug, Default)] +enum FetchFrame { + Num(usize), + #[default] + Chunked, +} + +impl<'de> serde::Deserialize<'de> for FetchFrame { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + if let Ok(num) = s.parse::() { + Ok(FetchFrame::Num(num)) + } else if s == "chunked" { + Ok(FetchFrame::Chunked) + } else { + Err(serde::de::Error::custom(format!("invalid frame value: {s}"))) + } + } +} + async fn serve_ws( ws: WebSocketUpgrade, Path(path): Path, - Query(params): Query, + Query(params): Query, State(state): State>, ) -> axum::response::Result { let ws = ws.protocols(["webtransport"]); @@ -229,7 +283,7 @@ where /// Serve the announced broadcasts for a given prefix. async fn serve_announced( path: Option>, - Query(params): Query, + Query(params): Query, State(state): State>, ) -> axum::response::Result { let prefix = match path { @@ -253,10 +307,10 @@ async fn serve_announced( Ok(broadcasts.iter().map(|p| p.to_string()).collect::>().join("\n")) } -/// Serve the latest group for a given track +/// Serve the given group for a given track async fn serve_fetch( Path(path): Path, - Query(params): Query, + Query(params): Query, State(state): State>, ) -> axum::response::Result { // The path containts a broadcast/track @@ -269,7 +323,7 @@ async fn serve_fetch( } let broadcast = path.join("/"); - let token = state.auth.verify(&broadcast, params.jwt.as_deref())?; + let token = state.auth.verify(&broadcast, params.auth.jwt.as_deref())?; let Some(origin) = state.cluster.subscriber(&token) else { return Err(StatusCode::UNAUTHORIZED.into()); @@ -286,39 +340,76 @@ async fn serve_fetch( let broadcast = origin.consume_broadcast("").ok_or(StatusCode::NOT_FOUND)?; let mut track = broadcast.subscribe_track(&track); - let Ok(group) = track.next_group().await else { - return Err(StatusCode::INTERNAL_SERVER_ERROR.into()); - }; - let Some(group) = group else { - return Err(StatusCode::NOT_FOUND.into()); - }; + let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(30); + + let result = tokio::time::timeout_at(deadline, async { + let group = match params.group { + FetchGroup::Latest => track.next_group().await, + FetchGroup::Num(sequence) => track.get_group(sequence).await, + }; + + let group = match group { + Ok(Some(group)) => group, + Ok(None) => return Err(StatusCode::NOT_FOUND), + Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; + + tracing::info!(track = %track.info.name, group = %group.info.sequence, "serving group"); + + match params.frame { + FetchFrame::Num(index) => match group.get_frame(index).await { + Ok(Some(frame)) => Ok(ServeGroup { + group: None, + frame: Some(frame), + deadline, + }), + Ok(None) => Err(StatusCode::NOT_FOUND), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + }, + FetchFrame::Chunked => Ok(ServeGroup { + group: Some(group), + frame: None, + deadline, + }), + } + }) + .await; - Ok(ServeGroup::new(group)) + match result { + Ok(Ok(serve)) => Ok(serve), + Ok(Err(status)) => Err(status.into()), + Err(_) => Err(StatusCode::GATEWAY_TIMEOUT.into()), + } } struct ServeGroup { - group: moq_lite::GroupConsumer, + group: Option, frame: Option, + deadline: tokio::time::Instant, } impl ServeGroup { - fn new(group: moq_lite::GroupConsumer) -> Self { - Self { group, frame: None } - } - async fn next(&mut self) -> moq_lite::Result> { - loop { + while self.group.is_some() || self.frame.is_some() { if let Some(frame) = self.frame.as_mut() { - let data = frame.read_all().await?; + let data = tokio::time::timeout_at(self.deadline, frame.read_all()) + .await + .map_err(|_| moq_lite::Error::Timeout)?; self.frame.take(); - return Ok(Some(data)); + return Ok(Some(data?)); } - self.frame = self.group.next_frame().await?; - if self.frame.is_none() { - return Ok(None); + if let Some(group) = self.group.as_mut() { + self.frame = tokio::time::timeout_at(self.deadline, group.next_frame()) + .await + .map_err(|_| moq_lite::Error::Timeout)??; + if self.frame.is_none() { + self.group.take(); + } } } + + Ok(None) } }