Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 224 additions & 59 deletions src/filter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{FilterConfig, index::load_minimizer_hashes};
use anyhow::{Context, Result};
use core::ops::Range;
use flate2::write::GzEncoder;
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use liblzma::write::XzEncoder;
Expand All @@ -23,6 +24,10 @@ use zstd::stream::write::Encoder as ZstdEncoder;

const OUTPUT_BUFFER_SIZE: usize = 8 * 1024 * 1024; // Opt: 8MB output buffer
const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
const SEQ_LEN_THRESHOLD: usize = 8000; // Seq batch size (in bp) for simd-minimizers
const DEFAULT_SEQ_BUFFER_SIZE: usize = 30000;
const DEFAULT_NUM_RECORDS_PER_BATCH: usize = SEQ_LEN_THRESHOLD / 100; // Estimated number of records per batch (assuming 100bp short reads)
const DEFAULT_ID_BUFFER_SIZE: usize = DEFAULT_NUM_RECORDS_PER_BATCH * 80;

type BoxedWriter = Box<dyn Write + Send>;

Expand Down Expand Up @@ -198,6 +203,7 @@ struct FilterProcessor {
local_buffer2: Vec<u8>, // Second buffer for paired output
local_stats: ProcessingStats,
filter_buffers: FilterBuffers,
records: RecordBuffer,

// Global state
global_writer: Arc<Mutex<BoxedWriter>>,
Expand All @@ -217,12 +223,98 @@ struct ProcessingStats {
output_seq_counter: u64,
}

#[derive(Clone)]
pub struct MinimalRecord<'a> {
id: &'a [u8],
seq: &'a [u8],
qual: Option<&'a [u8]>,
}

impl Record for MinimalRecord<'_> {
fn id(&self) -> &[u8] {
self.id
}

fn seq(&self) -> std::borrow::Cow<[u8]> {
std::borrow::Cow::Borrowed(self.seq)
}

fn seq_raw(&self) -> &[u8] {
self.seq
}

fn qual(&self) -> Option<&[u8]> {
self.qual
}
}

#[derive(Default, Clone)]
pub struct RecordBuffer {
pub id_buffer: Vec<u8>,
pub seq_buffer: Vec<u8>,
pub qual_buffer: Option<Vec<u8>>,
pub id_seq_ends: Vec<(usize, usize)>,
}

impl RecordBuffer {
pub fn new() -> Self {
Self {
id_buffer: Vec::with_capacity(DEFAULT_ID_BUFFER_SIZE),
seq_buffer: Vec::with_capacity(DEFAULT_SEQ_BUFFER_SIZE),
id_seq_ends: Vec::with_capacity(DEFAULT_NUM_RECORDS_PER_BATCH),
qual_buffer: None,
}
}

pub fn push_record<Rf: Record>(&mut self, record: &Rf) {
self.id_buffer.extend_from_slice(record.id());
self.seq_buffer.extend_from_slice(&record.seq());
self.id_seq_ends
.push((self.id_buffer.len(), self.seq_buffer.len()));
if let Some(qual) = record.qual() {
if self.qual_buffer.is_none() {
self.qual_buffer = Some(Vec::with_capacity(DEFAULT_BUFFER_SIZE));
}
self.qual_buffer.as_mut().unwrap().extend_from_slice(qual);
}
}

pub fn clear(&mut self) {
self.id_buffer.clear();
self.seq_buffer.clear();
self.id_seq_ends.clear();
if let Some(qual_buffer) = self.qual_buffer.as_mut() {
qual_buffer.clear();
}
}

pub fn iter(&self) -> impl ExactSizeIterator<Item = MinimalRecord> {
let mut id_start = 0;
let mut seq_start = 0;

self.id_seq_ends.iter().map(move |&(id_end, seq_end)| {
let rec = MinimalRecord {
id: &self.id_buffer[id_start..id_end],
seq: &self.seq_buffer[seq_start..seq_end],
qual: self
.qual_buffer
.as_ref()
.map(|qual_buffer| &qual_buffer[seq_start..seq_end]),
};
id_start = id_end;
seq_start = seq_end;
rec
})
}
}

#[derive(Default, Clone)]
struct FilterBuffers {
packed_seq: packed_seq::PackedSeqVec,
invalid_mask: Vec<u64>,
positions: Vec<u32>,
minimizer_values: Vec<u64>,
pub packed_seq: packed_seq::PackedSeqVec,
pub invalid_mask: Vec<u64>,
pub positions: Vec<u32>,
pub sk_positions: Vec<u32>,
pub minimizer_values: Vec<u64>,
}

impl FilterProcessor {
Expand Down Expand Up @@ -270,6 +362,7 @@ impl FilterProcessor {
local_buffer2: Vec::with_capacity(DEFAULT_BUFFER_SIZE),
local_stats: ProcessingStats::default(),
filter_buffers: FilterBuffers::default(),
records: RecordBuffer::new(),
global_writer: Arc::new(Mutex::new(writer)),
global_writer2: writer2.map(|w| Arc::new(Mutex::new(w))),
global_stats: Arc::new(Mutex::new(ProcessingStats::default())),
Expand All @@ -278,50 +371,35 @@ impl FilterProcessor {
}
}

fn should_keep_sequence(&mut self, seq: &[u8]) -> (bool, usize, usize, Vec<String>) {
if seq.len() < self.kmer_length as usize {
return (self.deplete, 0, 0, Vec::new()); // If too short, keep if in deplete mode
}

// Apply prefix length limit if specified
let effective_seq = if self.prefix_length > 0 && seq.len() > self.prefix_length {
&seq[..self.prefix_length]
} else {
seq
};

// Trim the last newline character from `effective_seq` if it has one.
let effective_seq = effective_seq.strip_suffix(b"\n").unwrap_or(effective_seq);

#[inline]
fn compute_minimizer_positions(&mut self, seq: &[u8]) {
let FilterBuffers {
packed_seq,
invalid_mask,
positions,
minimizer_values,
sk_positions,
minimizer_values: _,
} = &mut self.filter_buffers;

packed_seq.clear();
minimizer_values.clear();
positions.clear();
sk_positions.clear();
invalid_mask.clear();

// Pack the sequence into 2-bit representation.
// Any non-ACGT characters are silently converted to 2-bit ACGT as well.
packed_seq.push_ascii(effective_seq);
// let packed_seq = packed_seq::PackedSeqVec::from_ascii(effective_seq);
packed_seq.push_ascii(seq);
// let packed_seq = packed_seq::PackedSeqVec::from_ascii(seq);

// TODO: Extract this to some nicer helper function in packed_seq?
// TODO: Use SIMD?
// TODO: Should probably add some test for this.
// +2: one to round up, and one buffer.
invalid_mask.resize(packed_seq.len() / 64 + 2, 0);
// let mut invalid_mask = vec![0u64; packed_seq.len() / 64 + 2];
for i in (0..effective_seq.len()).step_by(64) {
for i in (0..seq.len()).step_by(64) {
let mut mask = 0;
for (j, b) in effective_seq[i..(i + 64).min(effective_seq.len())]
.iter()
.enumerate()
{
for (j, b) in seq[i..(i + 64).min(seq.len())].iter().enumerate() {
mask |= ((!matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't'))
as u64)
<< j;
Expand All @@ -330,13 +408,22 @@ impl FilterProcessor {
invalid_mask[i / 64] = mask;
}

// let mut positions = Vec::new();
simd_minimizers::canonical_minimizer_positions(
packed_seq.as_slice(),
self.kmer_length as usize,
self.window_size as usize,
positions,
);
if self.records.id_seq_ends.len() == 1 {
simd_minimizers::canonical_minimizer_positions(
packed_seq.as_slice(),
self.kmer_length as usize,
self.window_size as usize,
positions,
);
} else {
simd_minimizers::canonical_minimizer_and_superkmer_positions(
packed_seq.as_slice(),
self.kmer_length as usize,
self.window_size as usize,
positions,
sk_positions,
);
}

assert!(
self.kmer_length <= 56,
Expand All @@ -358,6 +445,18 @@ impl FilterProcessor {
(unsafe { invalid_mask.as_ptr().byte_add(byte).read_unaligned() } >> offset) & mask;
x == 0
});
}

#[inline]
fn should_keep_given_mini_pos_range(
&mut self,
seq: &[u8],
mini_pos_range: Range<usize>,
) -> (bool, usize, usize, Vec<String>) {
let packed_seq = &self.filter_buffers.packed_seq;
let positions = &self.filter_buffers.positions[mini_pos_range];
let minimizer_values = &mut self.filter_buffers.minimizer_values;
minimizer_values.clear();

// Get hash values for valid positions
minimizer_values.extend(
Expand All @@ -382,7 +481,7 @@ impl FilterProcessor {
// Extract the k-mer sequence at this position
if self.debug && i < positions.len() {
let pos = positions[i] as usize;
let kmer = &effective_seq[pos..pos + self.kmer_length as usize];
let kmer = &seq[pos..pos + self.kmer_length as usize];
hit_kmers.push(String::from_utf8_lossy(kmer).to_string());
}
}
Expand All @@ -396,6 +495,27 @@ impl FilterProcessor {
)
}

fn should_keep_sequence(&mut self, seq: &[u8]) -> (bool, usize, usize, Vec<String>) {
if seq.len() < self.kmer_length as usize {
return (self.deplete, 0, 0, Vec::new()); // If too short, keep if in deplete mode
}

// Apply prefix length limit if specified
let effective_seq = if self.prefix_length > 0 && seq.len() > self.prefix_length {
&seq[..self.prefix_length]
} else {
seq
};

// Trim the last newline character from `effective_seq` if it has one.
let effective_seq = effective_seq.strip_suffix(b"\n").unwrap_or(effective_seq);

self.compute_minimizer_positions(effective_seq);

let mini_pos_range = 0..self.filter_buffers.positions.len();
self.should_keep_given_mini_pos_range(effective_seq, mini_pos_range)
}

fn get_minimizer_hashes_and_positions(&self, seq: &[u8]) -> (Vec<u64>, Vec<u32>) {
// Canonicalize sequence
let canonical_seq = seq
Expand Down Expand Up @@ -562,40 +682,85 @@ impl FilterProcessor {
));
}
}
}

impl ParallelProcessor for FilterProcessor {
fn process_record<Rf: Record>(&mut self, record: Rf) -> paraseq::parallel::Result<()> {
let seq = record.seq();
self.local_stats.total_seqs += 1;
self.local_stats.total_bp += seq.len() as u64;
fn process_record_buffer(&mut self) -> paraseq::parallel::Result<()> {
self.local_stats.total_seqs += self.records.id_seq_ends.len() as u64;
self.local_stats.total_bp += self.records.seq_buffer.len() as u64;

let (should_keep, hit_count, total_minimizers, hit_kmers) = self.should_keep_sequence(&seq);
let mut records = RecordBuffer::default();
core::mem::swap(&mut records, &mut self.records);

// Show debug info for sequences with hits
if self.debug {
eprintln!(
"DEBUG: {} hits={}/{} keep={} kmers=[{}]",
String::from_utf8_lossy(record.id()),
hit_count,
total_minimizers,
should_keep,
hit_kmers.join(",")
);
}
self.compute_minimizer_positions(&records.seq_buffer);

if should_keep {
self.local_stats.output_bp += seq.len() as u64;
self.write_record(&record, &seq)?;
let mut mini_pos_ranges = Vec::with_capacity(records.id_seq_ends.len());
if records.id_seq_ends.len() == 1 {
mini_pos_ranges.push(0..self.filter_buffers.positions.len());
} else {
self.local_stats.filtered_seqs += 1;
self.local_stats.filtered_bp += seq.len() as u64;
let mut mini_idx = 0;
for seq_end in records.id_seq_ends.iter().map(|&(_, s)| s as u32) {
let mini_start = mini_idx;
let last_kmer_start = seq_end - self.kmer_length as u32 + 1;
while mini_idx < self.filter_buffers.sk_positions.len()
&& self.filter_buffers.sk_positions[mini_idx] < last_kmer_start
{
mini_idx += 1;
}
let mini_end = mini_idx;
while mini_idx + 1 < self.filter_buffers.sk_positions.len()
&& self.filter_buffers.sk_positions[mini_idx + 1] < seq_end
{
mini_idx += 1;
}
mini_pos_ranges.push(mini_start..mini_end);
}
};

for (record, mini_pos_range) in records.iter().zip(mini_pos_ranges.into_iter()) {
let (should_keep, hit_count, total_minimizers, hit_kmers) =
self.should_keep_given_mini_pos_range(&records.seq_buffer, mini_pos_range);

// Show debug info for sequences with hits
if self.debug {
eprintln!(
"DEBUG: {} hits={}/{} keep={} kmers=[{}]",
String::from_utf8_lossy(record.id()),
hit_count,
total_minimizers,
should_keep,
hit_kmers.join(",")
);
}

if should_keep {
self.local_stats.output_bp += record.seq.len() as u64;
self.write_record(&record, record.seq)?;
} else {
self.local_stats.filtered_seqs += 1;
self.local_stats.filtered_bp += record.seq.len() as u64;
}
}

records.clear();
core::mem::swap(&mut records, &mut self.records);

Ok(())
}
}

impl ParallelProcessor for FilterProcessor {
fn process_record<Rf: Record>(&mut self, record: Rf) -> paraseq::parallel::Result<()> {
self.records.push_record(&record);
if self.records.seq_buffer.len() < SEQ_LEN_THRESHOLD {
return Ok(());
}
self.process_record_buffer()
}

fn on_batch_complete(&mut self) -> paraseq::parallel::Result<()> {
if !self.records.seq_buffer.is_empty() {
self.process_record_buffer()?;
}

// Write buffer to output
if !self.local_buffer.is_empty() {
let mut global_writer = self.global_writer.lock();
Expand Down