From 04ba98675808f61ba1f625494255891c09681efb Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Sat, 14 Feb 2026 11:22:59 +0200 Subject: [PATCH 1/2] fix OOM in fast networks Currently the download buffer is bounded by item count, this is ok if chunks are 16KB, However if we get 256KB chunks for example, we go over the 128MB limit. This patch uses ByteBoundChannel to enforce the byte limit, even we are under the chunk size limit Signed-off-by: Benny Zlotnik Assisted-by: claude-opus-4.5 --- src/fls/byte_channel.rs | 272 ++++++++++++++++++++++++++++++++++++++++ src/fls/from_url.rs | 18 ++- src/fls/mod.rs | 1 + src/fls/oci/from_oci.rs | 38 +++--- src/fls/stream_utils.rs | 32 ++++- 5 files changed, 332 insertions(+), 29 deletions(-) create mode 100644 src/fls/byte_channel.rs diff --git a/src/fls/byte_channel.rs b/src/fls/byte_channel.rs new file mode 100644 index 0000000..ac1949d --- /dev/null +++ b/src/fls/byte_channel.rs @@ -0,0 +1,272 @@ +/// Byte-bounded channel wrapper for memory-safe streaming +/// +/// Wraps `mpsc::channel` with a `tokio::sync::Semaphore` to bound total +/// buffered bytes rather than item count. This prevents OOM when chunk +/// sizes vary (e.g., reqwest delivering 64-256KB chunks on fast networks). +use std::sync::Arc; +use tokio::sync::{mpsc, Semaphore}; + +/// Trait for items that know their byte size. +pub trait SizedItem { + fn byte_size(&self) -> usize; +} + +impl SizedItem for bytes::Bytes { + fn byte_size(&self) -> usize { + self.len() + } +} + +/// Sender half of a byte-bounded channel. +/// +/// Acquires semaphore permits equal to `chunk.byte_size()` before sending, +/// ensuring total buffered bytes never exceeds `max_bytes`. +pub struct ByteBoundedSender { + inner: mpsc::Sender, + semaphore: Arc, + max_bytes: usize, +} + +impl ByteBoundedSender { + /// Send an item, blocking (async) until enough byte budget is available. + /// + /// Acquires `min(item.byte_size(), max_bytes)` permits so a single + /// oversized chunk can still pass through without deadlocking. + pub async fn send(&self, item: T) -> Result<(), mpsc::error::SendError> { + let permits_needed = item.byte_size().min(self.max_bytes); + + let permits_needed_u32 = permits_needed as u32; + + // acquire_many_owned returns OwnedSemaphorePermit which we intentionally + // forget — the receiver side adds permits back after consuming the item. + let permit = self + .semaphore + .acquire_many(permits_needed_u32) + .await + .expect("semaphore closed unexpectedly"); + permit.forget(); + + self.inner.send(item).await + } +} + +/// Receiver half of a byte-bounded channel. +/// +/// Returns semaphore permits after receiving each item, freeing byte budget +/// for the sender. +pub struct ByteBoundedReceiver { + inner: mpsc::Receiver, + semaphore: Arc, + max_bytes: usize, +} + +impl ByteBoundedReceiver { + /// Receive an item asynchronously, releasing byte budget on success. + pub async fn recv(&mut self) -> Option { + let item = self.inner.recv().await?; + let to_release = item.byte_size().min(self.max_bytes); + self.semaphore.add_permits(to_release); + Some(item) + } + + /// Receive an item synchronously (for use in `spawn_blocking`), + /// releasing byte budget on success. + pub fn blocking_recv(&mut self) -> Option { + let item = self.inner.blocking_recv()?; + let to_release = item.byte_size().min(self.max_bytes); + self.semaphore.add_permits(to_release); + Some(item) + } +} + +/// Create a byte-bounded channel. +/// +/// - `max_bytes`: maximum total bytes buffered at any time (must be ≤ u32::MAX) +/// - `max_items`: underlying mpsc channel item capacity (prevents unbounded item queuing) +/// +/// # Panics +/// +/// Panics if `max_bytes` exceeds `u32::MAX` (4,294,967,295 bytes ≈ 4GB). +/// This limit exists because the underlying semaphore uses u32 permit counts. +pub fn byte_bounded_channel( + max_bytes: usize, + max_items: usize, +) -> (ByteBoundedSender, ByteBoundedReceiver) { + // Guard against overflow in send() method's permits_needed as u32 cast + if max_bytes > u32::MAX as usize { + panic!( + "max_bytes ({}) exceeds u32::MAX ({}). Maximum supported buffer size is ~4GB.", + max_bytes, + u32::MAX + ); + } + + let (tx, rx) = mpsc::channel::(max_items); + let semaphore = Arc::new(Semaphore::new(max_bytes)); + + let sender = ByteBoundedSender { + inner: tx, + semaphore: semaphore.clone(), + max_bytes, + }; + + let receiver = ByteBoundedReceiver { + inner: rx, + semaphore, + max_bytes, + }; + + (sender, receiver) +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use std::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn test_basic_send_receive() { + let (tx, mut rx) = byte_bounded_channel::(1024, 10); + + let data = Bytes::from_static(b"hello"); + tx.send(data.clone()).await.unwrap(); + + let received = rx.recv().await.unwrap(); + assert_eq!(received, data); + } + + #[tokio::test] + async fn test_byte_limit_enforcement() { + // 100-byte limit, 5 item capacity + let (tx, _rx) = byte_bounded_channel::(100, 5); + + // Send 80 bytes (should succeed) + let chunk1 = Bytes::from(vec![1u8; 80]); + tx.send(chunk1).await.unwrap(); + + // Send 20 bytes (should succeed, total = 100) + let chunk2 = Bytes::from(vec![2u8; 20]); + tx.send(chunk2).await.unwrap(); + + // Try to send 1 more byte (should block) + let chunk3 = Bytes::from(vec![3u8; 1]); + let send_future = tx.send(chunk3); + + // Should timeout because buffer is full + assert!(timeout(Duration::from_millis(50), send_future) + .await + .is_err()); + } + + #[tokio::test] + async fn test_permits_released_after_recv() { + let (tx, mut rx) = byte_bounded_channel::(100, 5); + + // Fill buffer to capacity + let chunk1 = Bytes::from(vec![1u8; 60]); + let chunk2 = Bytes::from(vec![2u8; 40]); + tx.send(chunk1).await.unwrap(); + tx.send(chunk2).await.unwrap(); + + // Buffer should be full, next send should block + let chunk3 = Bytes::from(vec![3u8; 1]); + let send_future = tx.send(chunk3.clone()); + assert!(timeout(Duration::from_millis(50), send_future) + .await + .is_err()); + + // Consume one chunk, freeing 60 bytes + let _received = rx.recv().await.unwrap(); + + // Now the blocked send should succeed + let send_future = tx.send(chunk3); + assert!(timeout(Duration::from_millis(50), send_future) + .await + .is_ok()); + } + + #[tokio::test] + async fn test_oversized_chunk_handling() { + // 50-byte limit + let (tx, mut rx) = byte_bounded_channel::(50, 5); + + // Send 100-byte chunk (larger than limit) + let big_chunk = Bytes::from(vec![1u8; 100]); + + // Should succeed (acquires min(100, 50) = 50 permits) + tx.send(big_chunk.clone()).await.unwrap(); + + // Should be able to receive it + let received = rx.recv().await.unwrap(); + assert_eq!(received, big_chunk); + } + + #[tokio::test] + async fn test_multiple_chunk_sizes() { + let (tx, mut rx) = byte_bounded_channel::(1000, 100); + + let chunks = vec![ + Bytes::from(vec![1u8; 100]), // Small + Bytes::from(vec![2u8; 500]), // Medium + Bytes::from(vec![3u8; 300]), // Large + Bytes::from(vec![4u8; 50]), // Tiny + ]; + + // Send all chunks + for chunk in &chunks { + tx.send(chunk.clone()).await.unwrap(); + } + + // Receive and verify + for expected in &chunks { + let received = rx.recv().await.unwrap(); + assert_eq!(received, *expected); + } + } + + #[tokio::test] + async fn test_channel_closure() { + let (tx, mut rx) = byte_bounded_channel::(100, 5); + + tx.send(Bytes::from_static(b"data")).await.unwrap(); + drop(tx); // Close sender + + // Should receive the sent data + let received = rx.recv().await.unwrap(); + assert_eq!(received, Bytes::from_static(b"data")); + + // Next recv should return None (channel closed) + assert!(rx.recv().await.is_none()); + } + + #[tokio::test] + async fn test_blocking_recv() { + let (tx, mut rx) = byte_bounded_channel::(100, 5); + + // Test in spawn_blocking since blocking_recv is sync + let handle = tokio::task::spawn_blocking(move || { + // This should block until data is available + rx.blocking_recv() + }); + + // Give it a moment to start blocking + tokio::time::sleep(Duration::from_millis(10)).await; + + // Send data + tx.send(Bytes::from_static(b"test")).await.unwrap(); + + // Should now unblock and return the data + let result = handle.await.unwrap(); + assert_eq!(result.unwrap(), Bytes::from_static(b"test")); + } + + #[test] + #[should_panic(expected = "max_bytes (4294967296) exceeds u32::MAX")] + fn test_max_bytes_overflow_guard() { + // Try to create a channel with max_bytes > u32::MAX + let oversized = (u32::MAX as usize) + 1; + let _ = byte_bounded_channel::(oversized, 100); + } +} diff --git a/src/fls/from_url.rs b/src/fls/from_url.rs index 632e891..e5653a7 100644 --- a/src/fls/from_url.rs +++ b/src/fls/from_url.rs @@ -5,6 +5,7 @@ use tokio::sync::mpsc; use tokio::task::JoinHandle; use crate::fls::block_writer::AsyncBlockWriter; +use crate::fls::byte_channel::byte_bounded_channel; use crate::fls::decompress::{spawn_stderr_reader, start_decompressor_process}; use crate::fls::download_error::DownloadError; use crate::fls::error_handling::process_error_messages; @@ -349,21 +350,18 @@ pub async fn flash_from_url( use futures_util::StreamExt; - // Calculate buffer capacity (shared across all retry attempts) + // Create byte-bounded download buffer (shared across all retry attempts) let buffer_size_mb = options.common.buffer_size_mb; - // HTTP chunks from reqwest are typically 8-32 KB, not 64 KB - // To ensure we get the full buffer size, use a conservative estimate - let avg_chunk_size_kb = 16; // From common obvervation: 16kb - let buffer_capacity = (buffer_size_mb * 1024) / avg_chunk_size_kb; - let buffer_capacity = buffer_capacity.max(1000); // At least 1000 chunks + let max_buffer_bytes = buffer_size_mb * 1024 * 1024; println!( - "Using download buffer: {} MB (capacity: {} chunks, ~{} KB per chunk)", - buffer_size_mb, buffer_capacity, avg_chunk_size_kb + "Using download buffer: {} MB (byte-bounded)", + buffer_size_mb ); - // Create persistent bounded channel for download buffering (lives across retries) - let (buffer_tx, mut buffer_rx) = mpsc::channel::(buffer_capacity); + // Create persistent byte-bounded channel for download buffering (lives across retries) + // max_items=4096 prevents unbounded item queuing; byte budget is the real bound + let (buffer_tx, mut buffer_rx) = byte_bounded_channel::(max_buffer_bytes, 4096); // Channels for tracking bytes actually written to decompressor let (decompressor_written_progress_tx, mut decompressor_written_progress_rx) = diff --git a/src/fls/mod.rs b/src/fls/mod.rs index 67736a5..5ad8016 100644 --- a/src/fls/mod.rs +++ b/src/fls/mod.rs @@ -1,6 +1,7 @@ // Module declarations pub mod automotive; mod block_writer; +pub mod byte_channel; pub(crate) mod compression; mod decompress; mod download_error; diff --git a/src/fls/oci/from_oci.rs b/src/fls/oci/from_oci.rs index 56f86c6..d0d135c 100644 --- a/src/fls/oci/from_oci.rs +++ b/src/fls/oci/from_oci.rs @@ -14,6 +14,8 @@ use tokio::io::AsyncWriteExt; use tokio::sync::mpsc; use xz2::read::XzDecoder; +use crate::fls::byte_channel::{byte_bounded_channel, ByteBoundedReceiver, ByteBoundedSender}; + use super::manifest::{LayerCompression, Manifest}; use super::reference::ImageReference; use super::registry::RegistryClient; @@ -35,7 +37,7 @@ const OCI_TITLE_ANNOTATION: &str = "org.opencontainers.image.title"; /// Parameters for download coordination functions struct DownloadCoordinationParams { - http_tx: mpsc::Sender, + http_tx: ByteBoundedSender, decompressed_progress_rx: mpsc::UnboundedReceiver, written_progress_rx: mpsc::UnboundedReceiver, decompressor_written_progress_rx: mpsc::UnboundedReceiver, @@ -50,7 +52,7 @@ struct DownloadContext { /// Parameters for raw disk download coordination struct RawDiskDownloadParams { - http_tx: mpsc::Sender, + http_tx: ByteBoundedSender, writer_handle: tokio::task::JoinHandle>, external_decompressor: Option, decompressed_progress_rx: mpsc::UnboundedReceiver, @@ -73,8 +75,8 @@ struct ExternalDecompressorPipeline { /// Components returned by pipeline setup struct TarPipelineComponents { - http_tx: mpsc::Sender, - http_rx: mpsc::Receiver, + http_tx: ByteBoundedSender, + http_rx: ByteBoundedReceiver, tar_tx: mpsc::Sender>, decompressed_progress_rx: mpsc::UnboundedReceiver, written_progress_rx: mpsc::UnboundedReceiver, @@ -880,12 +882,14 @@ async fn setup_tar_processing_pipeline( buffer_size_mb: usize, buffer_capacity: usize, ) -> Result> { + let max_buffer_bytes = buffer_size_mb * 1024 * 1024; println!( - "Using download buffer: {} MB (capacity: {} chunks)", - buffer_size_mb, buffer_capacity + "Using download buffer: {} MB (byte-bounded)", + buffer_size_mb ); - let (http_tx, http_rx) = mpsc::channel::(buffer_capacity); + let (http_tx, http_rx) = + byte_bounded_channel::(max_buffer_bytes, buffer_capacity); // Channel for tar entry data -> decompressor stdin let (tar_tx, mut tar_rx) = mpsc::channel::>(16); // 16 * 8MB = 128MB buffer @@ -1213,7 +1217,7 @@ async fn coordinate_download_and_processing( /// Setup external decompressor pipeline for XZ compression async fn setup_external_decompressor_pipeline( - http_rx: mpsc::Receiver, + http_rx: ByteBoundedReceiver, block_writer: AsyncBlockWriter, decompressed_progress_tx: mpsc::UnboundedSender, debug: bool, @@ -1262,7 +1266,7 @@ async fn setup_external_decompressor_pipeline( let stdin_writer_handle = { tokio::task::spawn_blocking(move || { use std::io::Write as _; - let reader = ChannelReader::new(http_rx); + let reader = ChannelReader::new_byte_bounded(http_rx); let mut reader = reader; let mut stdin = stdin_fd; let mut buffer = vec![0u8; 1024 * 1024]; // 1MB chunks @@ -1338,7 +1342,7 @@ async fn setup_external_decompressor_pipeline( /// Setup in-process decompression pipeline for Gzip or None compression async fn setup_inprocess_decompression_pipeline( - http_rx: mpsc::Receiver, + http_rx: ByteBoundedReceiver, block_writer: AsyncBlockWriter, decompressed_progress_tx: mpsc::UnboundedSender, compression_type: Compression, @@ -1353,7 +1357,7 @@ async fn setup_inprocess_decompression_pipeline( // Spawn blocking task: read, decompress, send to async channel let reader_handle = tokio::task::spawn_blocking(move || { - let reader = ChannelReader::new(http_rx); + let reader = ChannelReader::new_byte_bounded(http_rx); // Apply in-process gzip decompression if needed let processed_reader: Box = match compression_type { @@ -2283,11 +2287,13 @@ async fn flash_raw_disk_image_directly( options.common.write_buffer_size_mb, )?; - // Set up streaming pipeline using channels + // Set up byte-bounded streaming pipeline let buffer_size_mb = options.common.buffer_size_mb; - let buffer_capacity = ((buffer_size_mb * 1024) / 16).max(1000); // 16KB average chunk size + let max_buffer_bytes = buffer_size_mb * 1024 * 1024; + let buffer_capacity = ((buffer_size_mb * 1024) / 16).max(1000); // item cap for mpsc - let (http_tx, http_rx) = mpsc::channel::(buffer_capacity); + let (http_tx, http_rx) = + byte_bounded_channel::(max_buffer_bytes, buffer_capacity); let (decompressed_progress_tx, decompressed_progress_rx) = mpsc::unbounded_channel::(); // For gzip and none, we can decompress in-process and write directly to block writer @@ -2341,14 +2347,14 @@ async fn flash_raw_disk_image_directly( /// Simple tar archive extraction without the complex buffering logic fn extract_tar_archive_from_stream( - http_rx: mpsc::Receiver, + http_rx: ByteBoundedReceiver, tar_tx: mpsc::Sender>, file_pattern: Option<&str>, compression: LayerCompression, compression_type: Compression, debug: bool, ) -> Result<(), String> { - let reader = ChannelReader::new(http_rx); + let reader = ChannelReader::new_byte_bounded(http_rx); // Handle layer compression before tar extraction // Use both manifest compression and content-detected compression diff --git a/src/fls/stream_utils.rs b/src/fls/stream_utils.rs index 809d4ff..024c2fc 100644 --- a/src/fls/stream_utils.rs +++ b/src/fls/stream_utils.rs @@ -6,21 +6,47 @@ use bytes::Bytes; use std::io::Read; use tokio::sync::mpsc; +use crate::fls::byte_channel::ByteBoundedReceiver; + +/// Abstraction over plain and byte-bounded receivers +enum ReceiverVariant { + Plain(mpsc::Receiver), + ByteBounded(ByteBoundedReceiver), +} + +impl ReceiverVariant { + fn blocking_recv(&mut self) -> Option { + match self { + ReceiverVariant::Plain(rx) => rx.blocking_recv(), + ReceiverVariant::ByteBounded(rx) => rx.blocking_recv(), + } + } +} + /// Reader that pulls bytes from a tokio mpsc channel /// /// This bridges async HTTP streaming with synchronous readers /// like tar::Archive or flate2::GzDecoder. pub struct ChannelReader { - rx: mpsc::Receiver, + rx: ReceiverVariant, current: Option, offset: usize, } impl ChannelReader { - /// Create a new ChannelReader from an mpsc receiver + /// Create a new ChannelReader from a plain mpsc receiver pub fn new(rx: mpsc::Receiver) -> Self { Self { - rx, + rx: ReceiverVariant::Plain(rx), + current: None, + offset: 0, + } + } + + /// Create a new ChannelReader from a byte-bounded receiver + pub fn new_byte_bounded(rx: ByteBoundedReceiver) -> Self { + Self { + rx: ReceiverVariant::ByteBounded(rx), current: None, offset: 0, } From 764efde526237bd4817d1289e4598e3545d09c5d Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Sat, 14 Feb 2026 11:23:09 +0200 Subject: [PATCH 2/2] add tests Signed-off-by: Benny Zlotnik --- tests/byte_bounded_memory.rs | 198 +++++++++++++++++++++++++++++++++++ tests/common/mod.rs | 2 + 2 files changed, 200 insertions(+) create mode 100644 tests/byte_bounded_memory.rs diff --git a/tests/byte_bounded_memory.rs b/tests/byte_bounded_memory.rs new file mode 100644 index 0000000..ae09a22 --- /dev/null +++ b/tests/byte_bounded_memory.rs @@ -0,0 +1,198 @@ +// Integration tests for byte-bounded memory behavior + +use bytes::Bytes; +use fls::fls::byte_channel::byte_bounded_channel; +use std::time::Duration; +use tokio::time::timeout; + +/// Test that backpressure kicks in with large chunks +#[tokio::test] +async fn test_backpressure_with_large_chunks() { + let max_bytes = 128 * 1024; // 128KB limit + let (tx, mut rx) = byte_bounded_channel::(max_bytes, 100); + + // Step 1: Fill buffer to capacity + let chunk = Bytes::from(vec![1u8; 128 * 1024]); // Exactly the buffer size + tx.send(chunk).await.unwrap(); + + // Step 2: Try to send another chunk - this should block due to backpressure + let blocking_chunk = Bytes::from(vec![2u8; 64 * 1024]); + let send_task = tokio::spawn(async move { + tx.send(blocking_chunk).await.unwrap(); + "completed" + }); + + // Step 3: Verify the send is blocked (structural check, not timing) + tokio::time::sleep(Duration::from_millis(10)).await; + assert!( + !send_task.is_finished(), + "Send should be blocked by backpressure" + ); + + // Step 4: Consume the buffered chunk to free space + let _consumed = rx.recv().await.unwrap(); + + // Step 5: Now the blocked send should complete + let result = timeout(Duration::from_millis(100), send_task).await; + assert!(result.is_ok(), "Send should unblock after freeing space"); + assert_eq!(result.unwrap().unwrap(), "completed"); +} + +/// Test that small chunks don't artificially limit throughput +#[tokio::test] +async fn test_small_chunks_high_throughput() { + let max_bytes = 64 * 1024; // 64KB limit + let (tx, mut rx) = byte_bounded_channel::(max_bytes, 10000); + + let start_time = std::time::Instant::now(); + + // Producer: send many small chunks quickly + let producer = tokio::spawn(async move { + for i in 0..1000 { + let chunk = Bytes::from(vec![i as u8; 32]); // 32-byte chunks + tx.send(chunk).await.unwrap(); + } + }); + + // Consumer: receive all chunks + let consumer = tokio::spawn(async move { + let mut count = 0; + while let Some(_chunk) = rx.recv().await { + count += 1; + if count >= 1000 { + break; + } + } + count + }); + + let (_, received_count) = tokio::join!(producer, consumer); + let elapsed = start_time.elapsed(); + + assert_eq!(received_count.unwrap(), 1000); + // Should complete quickly (small chunks shouldn't be bottlenecked) + assert!( + elapsed < Duration::from_secs(1), + "Small chunks took too long: {:?}", + elapsed + ); +} + +/// Test backpressure behavior with mixed chunk sizes +#[tokio::test] +async fn test_backpressure_with_mixed_sizes() { + let max_bytes = 256 * 1024; // 256KB limit + let (tx, mut rx) = byte_bounded_channel::(max_bytes, 100); + + // Fill buffer with mixed-size chunks + tx.send(Bytes::from(vec![1u8; 128 * 1024])).await.unwrap(); // 128KB + tx.send(Bytes::from(vec![2u8; 64 * 1024])).await.unwrap(); // 64KB + tx.send(Bytes::from(vec![3u8; 32 * 1024])).await.unwrap(); // 32KB + // Total: 224KB (getting close to 256KB limit) + + // Try to send another 64KB chunk - this should block + let blocking_chunk = Bytes::from(vec![4u8; 64 * 1024]); // Would exceed limit + let send_task = tokio::spawn(async move { + tx.send(blocking_chunk).await.unwrap(); + "completed" + }); + + // Verify backpressure is working + tokio::time::sleep(Duration::from_millis(10)).await; + assert!( + !send_task.is_finished(), + "Send should be blocked when buffer would exceed limit" + ); + + // Consume the 128KB chunk to free space + let consumed = rx.recv().await.unwrap(); + assert_eq!(consumed.len(), 128 * 1024); + + // Now the blocked send should complete + let result = timeout(Duration::from_millis(100), send_task).await; + assert!(result.is_ok(), "Send should unblock after consuming data"); + assert_eq!(result.unwrap().unwrap(), "completed"); +} + +/// Test that oversized single chunks don't deadlock +#[tokio::test] +async fn test_oversized_chunk_no_deadlock() { + let max_bytes = 100 * 1024; // 100KB limit + let (tx, mut rx) = byte_bounded_channel::(max_bytes, 10); + + // Send a chunk larger than the buffer + let oversized_chunk = Bytes::from(vec![42u8; 256 * 1024]); // 256KB chunk + + let producer = tokio::spawn(async move { + tx.send(oversized_chunk.clone()).await.unwrap(); + oversized_chunk + }); + + let consumer = tokio::spawn(async move { rx.recv().await.unwrap() }); + + // Should complete without deadlock + let result = timeout(Duration::from_secs(1), async move { + let (sent, received) = tokio::join!(producer, consumer); + (sent.unwrap(), received.unwrap()) + }) + .await; + + assert!(result.is_ok(), "Oversized chunk should not deadlock"); + let (sent, received) = result.unwrap(); + assert_eq!(sent, received); +} + +/// Test property: regardless of chunk pattern, all data flows through correctly +#[tokio::test] +async fn test_chunk_size_independence() { + let max_bytes = 256 * 1024; // 256KB limit + + // Test different chunk patterns + let test_cases = vec![ + ("uniform_small", vec![4096; 100]), // 100 × 4KB + ("uniform_large", vec![64 * 1024; 10]), // 10 × 64KB + ("mixed", vec![1024, 32 * 1024, 1024, 128 * 1024, 1024]), // Mixed sizes + ("single_large", vec![200 * 1024]), // 1 × 200KB + ]; + + for (name, chunk_sizes) in test_cases { + println!("Testing chunk pattern: {}", name); + + let (tx, mut rx) = byte_bounded_channel::(max_bytes, 1000); + + // Producer: send the chunk pattern + let chunk_pattern = chunk_sizes.clone(); + let producer = tokio::spawn(async move { + let mut total_bytes = 0; + for (i, size) in chunk_pattern.iter().enumerate() { + let chunk = Bytes::from(vec![i as u8; *size]); + total_bytes += chunk.len(); + tx.send(chunk).await.unwrap(); + } + total_bytes + }); + + // Consumer: receive all chunks + let consumer = tokio::spawn(async move { + let mut total_bytes = 0; + let mut chunk_count = 0; + while let Some(chunk) = rx.recv().await { + total_bytes += chunk.len(); + chunk_count += 1; + + if chunk_count >= chunk_sizes.len() { + break; + } + } + total_bytes + }); + + let (sent_bytes, received_bytes) = tokio::join!(producer, consumer); + assert_eq!( + sent_bytes.unwrap(), + received_bytes.unwrap(), + "All bytes should flow through for pattern: {}", + name + ); + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 3790a71..5d62adb 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -5,6 +5,7 @@ use std::io::Write; use xz2::write::XzEncoder; /// Generate deterministic test data of a given size +#[allow(dead_code)] pub fn create_test_data(size: usize) -> Vec { // Create a repeating pattern for easier debugging let pattern = b"TESTDATA"; @@ -28,6 +29,7 @@ pub fn compress_xz(data: &[u8]) -> Vec { } /// Compress data using gzip compression +#[allow(dead_code)] pub fn compress_gz(data: &[u8]) -> Vec { let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); encoder