From fc426934be865a782d8a1722f42eac2994d31e0c Mon Sep 17 00:00:00 2001 From: Igor Martayan Date: Mon, 18 Aug 2025 18:00:21 +0200 Subject: [PATCH 1/2] Batch processing of short records by packing them together --- src/filter.rs | 283 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 224 insertions(+), 59 deletions(-) diff --git a/src/filter.rs b/src/filter.rs index f89047c..74ac30a 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -1,5 +1,6 @@ use crate::{FilterConfig, index::load_minimizer_hashes}; use anyhow::{Context, Result}; +use core::ops::Range; use flate2::write::GzEncoder; use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; use liblzma::write::XzEncoder; @@ -23,6 +24,10 @@ use zstd::stream::write::Encoder as ZstdEncoder; const OUTPUT_BUFFER_SIZE: usize = 8 * 1024 * 1024; // Opt: 8MB output buffer const DEFAULT_BUFFER_SIZE: usize = 64 * 1024; +const SEQ_LEN_THRESHOLD: usize = 8000; // Seq batch size (in bp) for simd-minimizers +const DEFAULT_SEQ_BUFFER_SIZE: usize = 30000; +const DEFAULT_NUM_RECORDS_PER_BATCH: usize = SEQ_LEN_THRESHOLD / 100; // Estimated number of records per batch (assuming 100bp short reads) +const DEFAULT_ID_BUFFER_SIZE: usize = DEFAULT_NUM_RECORDS_PER_BATCH * 80; type BoxedWriter = Box; @@ -198,6 +203,7 @@ struct FilterProcessor { local_buffer2: Vec, // Second buffer for paired output local_stats: ProcessingStats, filter_buffers: FilterBuffers, + records: RecordBuffer, // Global state global_writer: Arc>, @@ -217,12 +223,98 @@ struct ProcessingStats { output_seq_counter: u64, } +#[derive(Clone)] +pub struct MinimalRecord<'a> { + id: &'a [u8], + seq: &'a [u8], + qual: Option<&'a [u8]>, +} + +impl Record for MinimalRecord<'_> { + fn id(&self) -> &[u8] { + self.id + } + + fn seq(&self) -> std::borrow::Cow<[u8]> { + std::borrow::Cow::Borrowed(self.seq) + } + + fn seq_raw(&self) -> &[u8] { + self.seq + } + + fn qual(&self) -> Option<&[u8]> { + self.qual + } +} + +#[derive(Default, Clone)] +pub struct RecordBuffer { + pub id_buffer: Vec, + pub seq_buffer: Vec, + pub qual_buffer: Option>, + pub id_seq_ends: Vec<(usize, usize)>, +} + +impl RecordBuffer { + pub fn new() -> Self { + Self { + id_buffer: Vec::with_capacity(DEFAULT_ID_BUFFER_SIZE), + seq_buffer: Vec::with_capacity(DEFAULT_SEQ_BUFFER_SIZE), + id_seq_ends: Vec::with_capacity(DEFAULT_NUM_RECORDS_PER_BATCH), + qual_buffer: None, + } + } + + pub fn push_record(&mut self, record: &Rf) { + self.id_buffer.extend_from_slice(record.id()); + self.seq_buffer.extend_from_slice(&record.seq()); + self.id_seq_ends + .push((self.id_buffer.len(), self.seq_buffer.len())); + if let Some(qual) = record.qual() { + if self.qual_buffer.is_none() { + self.qual_buffer = Some(Vec::with_capacity(DEFAULT_BUFFER_SIZE)); + } + self.qual_buffer.as_mut().unwrap().extend_from_slice(qual); + } + } + + pub fn clear(&mut self) { + self.id_buffer.clear(); + self.seq_buffer.clear(); + self.id_seq_ends.clear(); + if let Some(qual_buffer) = self.qual_buffer.as_mut() { + qual_buffer.clear(); + } + } + + pub fn iter(&self) -> impl ExactSizeIterator { + let mut id_start = 0; + let mut seq_start = 0; + + self.id_seq_ends.iter().map(move |&(id_end, seq_end)| { + let rec = MinimalRecord { + id: &self.id_buffer[id_start..id_start], + seq: &self.seq_buffer[seq_start..seq_end], + qual: self + .qual_buffer + .as_ref() + .map(|qual_buffer| &qual_buffer[seq_start..seq_end]), + }; + id_start = id_end; + seq_start = seq_end; + rec + }) + } +} + #[derive(Default, Clone)] struct FilterBuffers { - packed_seq: packed_seq::PackedSeqVec, - invalid_mask: Vec, - positions: Vec, - minimizer_values: Vec, + pub packed_seq: packed_seq::PackedSeqVec, + pub invalid_mask: Vec, + pub positions: Vec, + pub sk_positions: Vec, + pub minimizer_values: Vec, } impl FilterProcessor { @@ -270,6 +362,7 @@ impl FilterProcessor { local_buffer2: Vec::with_capacity(DEFAULT_BUFFER_SIZE), local_stats: ProcessingStats::default(), filter_buffers: FilterBuffers::default(), + records: RecordBuffer::new(), global_writer: Arc::new(Mutex::new(writer)), global_writer2: writer2.map(|w| Arc::new(Mutex::new(w))), global_stats: Arc::new(Mutex::new(ProcessingStats::default())), @@ -278,37 +371,25 @@ impl FilterProcessor { } } - fn should_keep_sequence(&mut self, seq: &[u8]) -> (bool, usize, usize, Vec) { - if seq.len() < self.kmer_length as usize { - return (self.deplete, 0, 0, Vec::new()); // If too short, keep if in deplete mode - } - - // Apply prefix length limit if specified - let effective_seq = if self.prefix_length > 0 && seq.len() > self.prefix_length { - &seq[..self.prefix_length] - } else { - seq - }; - - // Trim the last newline character from `effective_seq` if it has one. - let effective_seq = effective_seq.strip_suffix(b"\n").unwrap_or(effective_seq); - + #[inline] + fn compute_minimizer_positions(&mut self, seq: &[u8]) { let FilterBuffers { packed_seq, invalid_mask, positions, - minimizer_values, + sk_positions, + minimizer_values: _, } = &mut self.filter_buffers; packed_seq.clear(); - minimizer_values.clear(); positions.clear(); + sk_positions.clear(); invalid_mask.clear(); // Pack the sequence into 2-bit representation. // Any non-ACGT characters are silently converted to 2-bit ACGT as well. - packed_seq.push_ascii(effective_seq); - // let packed_seq = packed_seq::PackedSeqVec::from_ascii(effective_seq); + packed_seq.push_ascii(seq); + // let packed_seq = packed_seq::PackedSeqVec::from_ascii(seq); // TODO: Extract this to some nicer helper function in packed_seq? // TODO: Use SIMD? @@ -316,12 +397,9 @@ impl FilterProcessor { // +2: one to round up, and one buffer. invalid_mask.resize(packed_seq.len() / 64 + 2, 0); // let mut invalid_mask = vec![0u64; packed_seq.len() / 64 + 2]; - for i in (0..effective_seq.len()).step_by(64) { + for i in (0..seq.len()).step_by(64) { let mut mask = 0; - for (j, b) in effective_seq[i..(i + 64).min(effective_seq.len())] - .iter() - .enumerate() - { + for (j, b) in seq[i..(i + 64).min(seq.len())].iter().enumerate() { mask |= ((!matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't')) as u64) << j; @@ -330,13 +408,22 @@ impl FilterProcessor { invalid_mask[i / 64] = mask; } - // let mut positions = Vec::new(); - simd_minimizers::canonical_minimizer_positions( - packed_seq.as_slice(), - self.kmer_length as usize, - self.window_size as usize, - positions, - ); + if self.records.id_seq_ends.len() == 1 { + simd_minimizers::canonical_minimizer_positions( + packed_seq.as_slice(), + self.kmer_length as usize, + self.window_size as usize, + positions, + ); + } else { + simd_minimizers::canonical_minimizer_and_superkmer_positions( + packed_seq.as_slice(), + self.kmer_length as usize, + self.window_size as usize, + positions, + sk_positions, + ); + } assert!( self.kmer_length <= 56, @@ -358,6 +445,18 @@ impl FilterProcessor { (unsafe { invalid_mask.as_ptr().byte_add(byte).read_unaligned() } >> offset) & mask; x == 0 }); + } + + #[inline] + fn should_keep_given_mini_pos_range( + &mut self, + seq: &[u8], + mini_pos_range: Range, + ) -> (bool, usize, usize, Vec) { + let packed_seq = &self.filter_buffers.packed_seq; + let positions = &self.filter_buffers.positions[mini_pos_range]; + let minimizer_values = &mut self.filter_buffers.minimizer_values; + minimizer_values.clear(); // Get hash values for valid positions minimizer_values.extend( @@ -382,7 +481,7 @@ impl FilterProcessor { // Extract the k-mer sequence at this position if self.debug && i < positions.len() { let pos = positions[i] as usize; - let kmer = &effective_seq[pos..pos + self.kmer_length as usize]; + let kmer = &seq[pos..pos + self.kmer_length as usize]; hit_kmers.push(String::from_utf8_lossy(kmer).to_string()); } } @@ -396,6 +495,27 @@ impl FilterProcessor { ) } + fn should_keep_sequence(&mut self, seq: &[u8]) -> (bool, usize, usize, Vec) { + if seq.len() < self.kmer_length as usize { + return (self.deplete, 0, 0, Vec::new()); // If too short, keep if in deplete mode + } + + // Apply prefix length limit if specified + let effective_seq = if self.prefix_length > 0 && seq.len() > self.prefix_length { + &seq[..self.prefix_length] + } else { + seq + }; + + // Trim the last newline character from `effective_seq` if it has one. + let effective_seq = effective_seq.strip_suffix(b"\n").unwrap_or(effective_seq); + + self.compute_minimizer_positions(effective_seq); + + let mini_pos_range = 0..self.filter_buffers.positions.len(); + self.should_keep_given_mini_pos_range(effective_seq, mini_pos_range) + } + fn get_minimizer_hashes_and_positions(&self, seq: &[u8]) -> (Vec, Vec) { // Canonicalize sequence let canonical_seq = seq @@ -562,40 +682,85 @@ impl FilterProcessor { )); } } -} -impl ParallelProcessor for FilterProcessor { - fn process_record(&mut self, record: Rf) -> paraseq::parallel::Result<()> { - let seq = record.seq(); - self.local_stats.total_seqs += 1; - self.local_stats.total_bp += seq.len() as u64; + fn process_record_buffer(&mut self) -> paraseq::parallel::Result<()> { + self.local_stats.total_seqs += self.records.id_seq_ends.len() as u64; + self.local_stats.total_bp += self.records.seq_buffer.len() as u64; - let (should_keep, hit_count, total_minimizers, hit_kmers) = self.should_keep_sequence(&seq); + let mut records = RecordBuffer::default(); + core::mem::swap(&mut records, &mut self.records); - // Show debug info for sequences with hits - if self.debug { - eprintln!( - "DEBUG: {} hits={}/{} keep={} kmers=[{}]", - String::from_utf8_lossy(record.id()), - hit_count, - total_minimizers, - should_keep, - hit_kmers.join(",") - ); - } + self.compute_minimizer_positions(&records.seq_buffer); - if should_keep { - self.local_stats.output_bp += seq.len() as u64; - self.write_record(&record, &seq)?; + let mut mini_pos_ranges = Vec::with_capacity(records.id_seq_ends.len()); + if records.id_seq_ends.len() == 1 { + mini_pos_ranges.push(0..self.filter_buffers.positions.len()); } else { - self.local_stats.filtered_seqs += 1; - self.local_stats.filtered_bp += seq.len() as u64; + let mut mini_idx = 0; + for seq_end in records.id_seq_ends.iter().map(|&(_, s)| s as u32) { + let mini_start = mini_idx; + let last_kmer_start = seq_end - self.kmer_length as u32 + 1; + while mini_idx < self.filter_buffers.sk_positions.len() + && self.filter_buffers.sk_positions[mini_idx] < last_kmer_start + { + mini_idx += 1; + } + let mini_end = mini_idx; + while mini_idx + 1 < self.filter_buffers.sk_positions.len() + && self.filter_buffers.sk_positions[mini_idx + 1] < seq_end + { + mini_idx += 1; + } + mini_pos_ranges.push(mini_start..mini_end); + } + }; + + for (record, mini_pos_range) in records.iter().zip(mini_pos_ranges.into_iter()) { + let (should_keep, hit_count, total_minimizers, hit_kmers) = + self.should_keep_given_mini_pos_range(&records.seq_buffer, mini_pos_range); + + // Show debug info for sequences with hits + if self.debug { + eprintln!( + "DEBUG: {} hits={}/{} keep={} kmers=[{}]", + String::from_utf8_lossy(record.id()), + hit_count, + total_minimizers, + should_keep, + hit_kmers.join(",") + ); + } + + if should_keep { + self.local_stats.output_bp += record.seq.len() as u64; + self.write_record(&record, record.seq)?; + } else { + self.local_stats.filtered_seqs += 1; + self.local_stats.filtered_bp += record.seq.len() as u64; + } } + records.clear(); + core::mem::swap(&mut records, &mut self.records); + Ok(()) } +} + +impl ParallelProcessor for FilterProcessor { + fn process_record(&mut self, record: Rf) -> paraseq::parallel::Result<()> { + self.records.push_record(&record); + if self.records.seq_buffer.len() < SEQ_LEN_THRESHOLD { + return Ok(()); + } + self.process_record_buffer() + } fn on_batch_complete(&mut self) -> paraseq::parallel::Result<()> { + if !self.records.seq_buffer.is_empty() { + self.process_record_buffer()?; + } + // Write buffer to output if !self.local_buffer.is_empty() { let mut global_writer = self.global_writer.lock(); From ab318c03af7d821509b65103084e3ed64027803c Mon Sep 17 00:00:00 2001 From: Igor Martayan Date: Mon, 18 Aug 2025 18:27:33 +0200 Subject: [PATCH 2/2] Fix MinimalRecord id range --- src/filter.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/filter.rs b/src/filter.rs index 74ac30a..413d1f7 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -294,7 +294,7 @@ impl RecordBuffer { self.id_seq_ends.iter().map(move |&(id_end, seq_end)| { let rec = MinimalRecord { - id: &self.id_buffer[id_start..id_start], + id: &self.id_buffer[id_start..id_end], seq: &self.seq_buffer[seq_start..seq_end], qual: self .qual_buffer