Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion web-transport-proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ categories = ["network-programming", "web-programming"]
bytes = "1"
http = "1"
thiserror = "2"

sfv = "0.14.0"
tracing = "0.1.44"
# Just for AsyncRead and AsyncWrite traits
tokio = { version = "1", default-features = false, features = ["io-util"] }
url = "2"
233 changes: 227 additions & 6 deletions web-transport-proto/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ use super::{qpack, Frame, VarInt};

use thiserror::Error;

mod protocol_negotiation {
//! WebTransport sub-protocol negotiation,
//!
//! according to [draft 14](https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-14.html#section-3.3)

/// The header name for the available protocols, sent within the WebTransport Connect request.
pub const AVAILABLE_NAME: &str = "wt-available-protocols";
/// The header name for the selected protocol, sent within the WebTransport Connect response.
pub const SELECTED_NAME: &str = "wt-protocol";
}

// Errors that can occur during the connect request.
#[derive(Error, Debug, Clone)]
pub enum ConnectError {
Expand Down Expand Up @@ -47,6 +58,9 @@ pub enum ConnectError {
#[error("expected path header")]
WrongPath,

#[error("header parsing error: field={field}, error={error}")]
HeaderError { field: &'static str, error: String },

#[error("non-200 status: {0:?}")]
ErrorStatus(http::StatusCode),

Expand All @@ -61,8 +75,12 @@ impl From<std::io::Error> for ConnectError {
}

#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
pub struct ConnectRequest {
/// The URL to connect to.
pub url: Url,
/// The webtransport sub protocols requested (if any).
pub protocols: Vec<String>,
}

impl ConnectRequest {
Expand Down Expand Up @@ -102,9 +120,45 @@ impl ConnectRequest {
return Err(ConnectError::WrongProtocol(protocol.map(|s| s.to_string())));
}

let protocols = if let Some(protocols) = headers
.get(protocol_negotiation::AVAILABLE_NAME)
.and_then(|protocols| {
sfv::Parser::new(protocols)
.parse::<sfv::List>()
.inspect_err(|error| {
tracing::error!(
?error,
"Failed to parse protocols as structured header field"
);
})
// if parsing of the field fails, the spec says we should ignore it and continue
.ok()
}) {
let total_items = protocols.len();
let final_protocols: Vec<String> = protocols
.into_iter()
.filter_map(|item| match item {
sfv::ListEntry::Item(sfv::Item {
bare_item: sfv::BareItem::String(s),
..
}) => Some(s.to_string()),
_ => None,
})
.collect();
if final_protocols.len() != total_items {
// we had non-string items in the list, according to the spec
// we should ignore the entire list
Vec::with_capacity(0)
} else {
final_protocols
}
} else {
Vec::with_capacity(0)
};

let url = Url::parse(&format!("{scheme}://{authority}{path_and_query}"))?;

Ok(Self { url })
Ok(Self { url, protocols })
}

pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
Expand All @@ -123,7 +177,7 @@ impl ConnectRequest {
}
}

pub fn encode<B: BufMut>(&self, buf: &mut B) {
pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), ConnectError> {
let mut headers = qpack::Headers::default();
headers.set(":method", "CONNECT");
headers.set(":scheme", self.url.scheme());
Expand All @@ -134,6 +188,25 @@ impl ConnectRequest {
};
headers.set(":path", &path_and_query);
headers.set(":protocol", "webtransport");
if !self.protocols.is_empty() {
// generate a proper StructuredField List header of the protocols given
let mut items = Vec::new();
for protocol in &self.protocols {
items.push(sfv::ListEntry::Item(sfv::Item::new(
sfv::StringRef::from_str(protocol.as_str()).map_err(|err| {
ConnectError::HeaderError {
field: protocol_negotiation::AVAILABLE_NAME,
error: err.to_string(),
}
})?,
)));
}
let mut ser = sfv::ListSerializer::new();
ser.members(items.iter());
if let Some(protocols) = ser.finish() {
headers.set(protocol_negotiation::AVAILABLE_NAME, protocols.as_str());
}
}

// Use a temporary buffer so we can compute the size.
let mut tmp = Vec::new();
Expand All @@ -143,19 +216,24 @@ impl ConnectRequest {
Frame::HEADERS.encode(buf);
size.encode(buf);
buf.put_slice(&tmp);
Ok(())
}

pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
let mut buf = BytesMut::new();
self.encode(&mut buf);
self.encode(&mut buf)?;
stream.write_all_buf(&mut buf).await?;
Ok(())
}
}

#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
pub struct ConnectResponse {
/// The status code of the response.
pub status: http::status::StatusCode,
/// The webtransport sub protocol selected by the server, if any
pub protocol: Option<String>,
}

impl ConnectResponse {
Expand All @@ -178,7 +256,19 @@ impl ConnectResponse {
o => return Err(ConnectError::WrongStatus(o)),
};

Ok(Self { status })
let protocol = headers
.get(protocol_negotiation::SELECTED_NAME)
.and_then(|s| {
let item = sfv::Parser::new(s)
.parse::<sfv::Item>()
.map_err(|error| {
tracing::error!(?error, "Failed to parse protocol header item. ignoring");
})
.ok()?;
item.bare_item.as_string().map(|rf| rf.to_string())
});

Ok(Self { status, protocol })
}

pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
Expand All @@ -197,10 +287,26 @@ impl ConnectResponse {
}
}

pub fn encode<B: BufMut>(&self, buf: &mut B) {
pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), ConnectError> {
let mut headers = qpack::Headers::default();
headers.set(":status", self.status.as_str());
headers.set("sec-webtransport-http3-draft", "draft02");
if let Some(protocol) = &self.protocol {
let serialized_item = sfv::ItemSerializer::new()
.bare_item(
sfv::StringRef::from_str(protocol.as_str()).map_err(|error| {
ConnectError::HeaderError {
field: protocol_negotiation::SELECTED_NAME,
error: error.to_string(),
}
})?,
)
.finish();
headers.set(
protocol_negotiation::SELECTED_NAME,
serialized_item.as_str(),
);
}

// Use a temporary buffer so we can compute the size.
let mut tmp = Vec::new();
Expand All @@ -210,12 +316,127 @@ impl ConnectResponse {
Frame::HEADERS.encode(buf);
size.encode(buf);
buf.put_slice(&tmp);
Ok(())
}

pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
let mut buf = BytesMut::new();
self.encode(&mut buf);
self.encode(&mut buf)?;
stream.write_all_buf(&mut buf).await?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use http::StatusCode;

use super::*;

#[test]
pub fn test_request_encode_decode_simple() {
let response = ConnectRequest {
url: "https://example.com".parse().unwrap(),
protocols: vec![],
};
let mut buf = BytesMut::new();
response.encode(&mut buf).unwrap();
let decoded = ConnectRequest::decode(&mut buf).unwrap();
assert_eq!(response, decoded);
}

#[test]
pub fn test_request_encode_decode_with_protocol() {
let response = ConnectRequest {
url: "https://example.com".parse().unwrap(),
protocols: vec!["protocol-1".to_string(), "protocol-2".to_string()],
};
let mut buf = BytesMut::new();
response.encode(&mut buf).unwrap();
let decoded = ConnectRequest::decode(&mut buf).unwrap();
assert_eq!(response, decoded);
}

#[test]
pub fn test_request_encode_decode_with_protocol_with_quotes() {
let response = ConnectRequest {
url: "https://example.com".parse().unwrap(),
protocols: vec!["protocol-\"1\"".to_string(), "protocol-'2'".to_string()],
};
let mut buf = BytesMut::new();
response.encode(&mut buf).unwrap();
let decoded = ConnectRequest::decode(&mut buf).unwrap();
assert_eq!(response, decoded);
}

#[test]
pub fn test_request_encode_decode_with_non_compliant_protocol() {
let response = ConnectRequest {
url: "https://example.com".parse().unwrap(),
protocols: vec!["protocol-🐕".to_string()],
};
let mut buf = BytesMut::new();
let resp = response.encode(&mut buf);
assert!(resp.is_err(), "non ascii must fail");
assert!(matches!(
resp,
Err(ConnectError::HeaderError {
field: protocol_negotiation::AVAILABLE_NAME,
..
})
));
}
#[test]
pub fn test_response_encode_decode_simple() {
let response = ConnectResponse {
status: StatusCode::ACCEPTED,
protocol: None,
};
let mut buf = BytesMut::new();
response.encode(&mut buf).unwrap();
let decoded = ConnectResponse::decode(&mut buf).unwrap();
assert_eq!(response, decoded);
}

#[test]
pub fn test_response_encode_decode_with_protocol() {
let response = ConnectResponse {
status: StatusCode::ACCEPTED,
protocol: Some("proto".to_string()),
};
let mut buf = BytesMut::new();
response.encode(&mut buf).unwrap();
let decoded = ConnectResponse::decode(&mut buf).unwrap();
assert_eq!(response, decoded);
}

#[test]
pub fn test_response_encode_decode_with_protocol_with_quotes() {
let response = ConnectResponse {
status: StatusCode::ACCEPTED,
protocol: Some("'proto'-\"1\"".to_string()),
};
let mut buf = BytesMut::new();
response.encode(&mut buf).unwrap();
let decoded = ConnectResponse::decode(&mut buf).unwrap();
assert_eq!(response, decoded);
}

#[test]
pub fn test_response_encode_decode_with_noncompliant_protocol() {
let response = ConnectResponse {
status: StatusCode::ACCEPTED,
protocol: Some("proto-😅".to_string()),
};
let mut buf = BytesMut::new();
let resp = response.encode(&mut buf);
assert!(resp.is_err(), "non ascii must fail");
assert!(matches!(
resp,
Err(ConnectError::HeaderError {
field: protocol_negotiation::SELECTED_NAME,
..
})
));
}
}
10 changes: 8 additions & 2 deletions web-transport-quiche/src/h3/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ impl Connect {
///
/// This is called by the server to accept or reject the connection.
pub async fn respond(&mut self, status: http::StatusCode) -> Result<(), ConnectError> {
let response = ConnectResponse { status };
let response = ConnectResponse {
status,
protocol: None,
};
tracing::debug!(?response, "sending CONNECT");
response.write(&mut self.send).await?;

Expand All @@ -77,7 +80,10 @@ impl Connect {
let (mut send, mut recv) = conn.open_bi().await?;

// Create a new CONNECT request that we'll send using HTTP/3
let request = ConnectRequest { url };
let request = ConnectRequest {
url,
protocols: vec![],
};

tracing::debug!(?request, "sending CONNECT");
request.write(&mut send).await?;
Expand Down
1 change: 1 addition & 0 deletions web-transport-quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ clap = { version = "4", features = ["derive"] }
rustls-pemfile = "2"
tokio = { version = "1", features = ["full"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
rcgen = "0.14.6"
Loading