diff --git a/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-macosx_10_12_x86_64.whl b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-macosx_10_12_x86_64.whl new file mode 100644 index 0000000..eb3751b Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-macosx_10_12_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-macosx_11_0_arm64.whl b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-macosx_11_0_arm64.whl new file mode 100644 index 0000000..d7c6354 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-macosx_11_0_arm64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl new file mode 100644 index 0000000..c5dae76 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl new file mode 100644 index 0000000..71427f8 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-win_amd64.whl b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-win_amd64.whl new file mode 100644 index 0000000..e5411ae Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp310-cp310-win_amd64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-macosx_10_12_x86_64.whl b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-macosx_10_12_x86_64.whl new file mode 100644 index 0000000..fdf11c7 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-macosx_10_12_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-macosx_11_0_arm64.whl 
b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-macosx_11_0_arm64.whl new file mode 100644 index 0000000..b27e754 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-macosx_11_0_arm64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl new file mode 100644 index 0000000..9e002b2 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl new file mode 100644 index 0000000..879e62c Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-win_amd64.whl b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-win_amd64.whl new file mode 100644 index 0000000..89211e6 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp311-cp311-win_amd64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-macosx_10_13_x86_64.whl b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-macosx_10_13_x86_64.whl new file mode 100644 index 0000000..fbd2b33 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-macosx_10_13_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-macosx_11_0_arm64.whl b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-macosx_11_0_arm64.whl new file mode 100644 index 0000000..238ee8e Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-macosx_11_0_arm64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl 
b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl new file mode 100644 index 0000000..a549fed Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl new file mode 100644 index 0000000..f24f594 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-win_amd64.whl b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-win_amd64.whl new file mode 100644 index 0000000..0efd76d Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp312-cp312-win_amd64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-macosx_10_13_x86_64.whl b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-macosx_10_13_x86_64.whl new file mode 100644 index 0000000..4b70dd0 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-macosx_10_13_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-macosx_11_0_arm64.whl b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-macosx_11_0_arm64.whl new file mode 100644 index 0000000..ec37bcf Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-macosx_11_0_arm64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl new file mode 100644 index 0000000..e8b3302 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 
b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl new file mode 100644 index 0000000..cda7d8f Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-win_amd64.whl b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-win_amd64.whl new file mode 100644 index 0000000..f2824a6 Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0-cp313-cp313-win_amd64.whl differ diff --git a/for_pypi_upload/seq_smith-0.1.0.tar.gz b/for_pypi_upload/seq_smith-0.1.0.tar.gz new file mode 100644 index 0000000..1d813ce Binary files /dev/null and b/for_pypi_upload/seq_smith-0.1.0.tar.gz differ diff --git a/seq_smith/__init__.py b/seq_smith/__init__.py index 77c89c4..1013c98 100644 --- a/seq_smith/__init__.py +++ b/seq_smith/__init__.py @@ -10,6 +10,8 @@ local_global_align_many, overlap_align, overlap_align_many, + top_k_ungapped_local_align, + top_k_ungapped_local_align_many, ) from .python_utils import decode, encode, format_alignment_ascii, generate_cigar, make_score_matrix @@ -30,4 +32,6 @@ "make_score_matrix", "overlap_align", "overlap_align_many", + "top_k_ungapped_local_align", + "top_k_ungapped_local_align_many", ] diff --git a/seq_smith/_seq_smith.pyi b/seq_smith/_seq_smith.pyi index 2e78a35..64cd598 100644 --- a/seq_smith/_seq_smith.pyi +++ b/seq_smith/_seq_smith.pyi @@ -117,3 +117,20 @@ def overlap_align_many( gap_extend: int, num_threads: int | None = None, ) -> list[Alignment]: ... +def top_k_ungapped_local_align( + seqa: bytes, + seqb: bytes, + score_matrix: npt.NDArray[np.int32], + k: int, + filter_overlap_a: bool = True, + filter_overlap_b: bool = True, +) -> list[Alignment]: ... 
+def top_k_ungapped_local_align_many( + seqa: bytes, + seqbs: Sequence[bytes], + score_matrix: npt.NDArray[np.int32], + k: int, + num_threads: int | None = None, + filter_overlap_a: bool = True, + filter_overlap_b: bool = True, +) -> list[list[Alignment]]: ... diff --git a/src/lib.rs b/src/lib.rs index a1cc943..349d584 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,8 @@ use pyo3::types::PyBytes; use pyo3::wrap_pyfunction; use pyo3_stub_gen::{define_stub_info_gatherer, derive::*}; use rayon::prelude::*; +use std::cmp::Ordering; +use std::collections::BinaryHeap; /// Represents the type of an alignment fragment. #[gen_stub_pyclass_enum] @@ -114,8 +116,6 @@ struct Alignment { struct AlignmentParams<'a> { sa: &'a Vec, sb: &'a Vec, - sa_len: usize, - sb_len: usize, score_matrix: &'a Array2, gap_open: i32, gap_extend: i32, @@ -156,8 +156,6 @@ impl<'a> AlignmentParams<'a> { )); } Ok(Self { - sa_len: seqa.len(), - sb_len: seqb.len(), sa: seqa, sb: seqb, score_matrix, @@ -182,6 +180,43 @@ impl<'a> AlignmentParams<'a> { } } +struct UngappedAlignmentParams<'a> { + sa: &'a Vec, + sb: &'a Vec, + score_matrix: &'a Array2, +} + +impl<'a> UngappedAlignmentParams<'a> { + fn new(seqa: &'a Vec, seqb: &'a Vec, score_matrix: &'a Array2) -> PyResult { + if seqa.is_empty() || seqb.is_empty() { + return Err(PyErr::new::( + "Input sequences cannot be empty.", + )); + } + if score_matrix.ndim() != 2 { + return Err(PyErr::new::( + "Score matrix must be 2-dimensional.", + )); + } + let (rows, cols) = score_matrix.dim(); + if rows != cols { + return Err(PyErr::new::( + "Score matrix must be square.", + )); + } + Ok(Self { + sa: seqa, + sb: seqb, + score_matrix, + }) + } + + #[inline(always)] + fn match_score(&self, row: usize, col: usize) -> i32 { + self.score_matrix[[self.sa[col] as usize, self.sb[row] as usize]] + } +} + struct AlignmentData { curr_score: Array1, prev_score: Array1, @@ -196,11 +231,11 @@ impl AlignmentData { fn new(params: &AlignmentParams) -> Self { unsafe { Self { - 
curr_score: Array1::uninit(params.sb_len).assume_init(), - prev_score: Array1::uninit(params.sb_len).assume_init(), - dir_matrix: Array2::uninit((params.sa_len, params.sb_len)).assume_init(), - hgap_pos: Array1::uninit(params.sb_len).assume_init(), - hgap_score: Array1::uninit(params.sb_len).assume_init(), + curr_score: Array1::uninit(params.sb.len()).assume_init(), + prev_score: Array1::uninit(params.sb.len()).assume_init(), + dir_matrix: Array2::uninit((params.sa.len(), params.sb.len())).assume_init(), + hgap_pos: Array1::uninit(params.sb.len()).assume_init(), + hgap_score: Array1::uninit(params.sb.len()).assume_init(), vgap_pos: -1, vgap_score: 0, } @@ -377,7 +412,7 @@ fn traceback( if residue_a == residue_b { stats.num_exact_matches += 1; } else { - let score = params.score_matrix[[residue_a as usize, residue_b as usize]]; + let score = params.score_matrix[[residue_a as usize, residue_b as usize]]; // residue codes index the matrix directly; match_score() is called with sequence positions (row, col) elsewhere, so it must not be used here if score > 0 { stats.num_positive_mismatches += 1; } else { @@ -446,13 +481,13 @@ fn _local_align_core(params: AlignmentParams) -> PyResult { } }; - for row in 0..params.sb_len { + for row in 0..params.sb.len() { data.hgap_pos[row] = -1; data.hgap_score[row] = params.gap_open; data.prev_score[row] = 0; } - for col in 0..params.sa_len { + for col in 0..params.sa.len() { data.vgap_pos = -1; data.vgap_score = params.gap_open; @@ -461,7 +496,7 @@ fn _local_align_core(params: AlignmentParams) -> PyResult { data.write_cell(0, col, score, dir); data.update_gaps(0, col, score, ¶ms); - for row in 1..params.sb_len { + for row in 1..params.sb.len() { let match_score = data.prev_score[row - 1].saturating_add(params.match_score(row, col)); let (score, dir) = data.compute_cell_clipped(row, col, match_score); update_max_score(score, row, col); @@ -512,13 +547,7 @@ fn local_align<'py>( let score_matrix = score_matrix.as_array().into_owned(); py.detach(move || { - let params = AlignmentParams::new( - &seqa, - &seqb, - &score_matrix, - gap_open, - gap_extend, - )?; + let params = 
AlignmentParams::new(&seqa, &seqb, &score_matrix, gap_open, gap_extend)?; _local_align_core(params) }) } @@ -556,14 +585,14 @@ fn local_align_many<'py>( fn _global_align_core(params: AlignmentParams) -> PyResult { let mut data = AlignmentData::new(¶ms); - for row in 0..params.sb_len { + for row in 0..params.sb.len() { let score = params.gap_cost(row as i32 + 1); data.prev_score[row] = score; data.hgap_pos[row] = -1; data.hgap_score[row] = score.saturating_add(params.gap_open); } - for col in 0..params.sa_len { + for col in 0..params.sa.len() { data.vgap_pos = -1; data.vgap_score = params .gap_cost(col as i32 + 1) @@ -576,7 +605,7 @@ fn _global_align_core(params: AlignmentParams) -> PyResult { let (score, _) = data.compute_and_write_cell(0, col, match_score); data.update_gaps(0, col, score, ¶ms); - for row in 1..params.sb_len { + for row in 1..params.sb.len() { let match_score = data.prev_score[row - 1].saturating_add(params.match_score(row, col)); let (score, _) = data.compute_and_write_cell(row, col, match_score); data.update_gaps(row, col, score, ¶ms); @@ -584,9 +613,15 @@ fn _global_align_core(params: AlignmentParams) -> PyResult { data.swap_scores(); } - let final_score = data.prev_score[params.sb_len - 1]; - let (fragments, stats) = - traceback(&data, ¶ms, params.sa_len - 1, params.sb_len - 1, true, true); + let final_score = data.prev_score[params.sb.len() - 1]; + let (fragments, stats) = traceback( + &data, + ¶ms, + params.sa.len() - 1, + params.sb.len() - 1, + true, + true, + ); Ok(Alignment { fragments: fragments, @@ -627,13 +662,7 @@ fn global_align<'py>( let score_matrix = score_matrix.as_array().into_owned(); py.detach(move || { - let params = AlignmentParams::new( - &seqa, - &seqb, - &score_matrix, - gap_open, - gap_extend, - )?; + let params = AlignmentParams::new(&seqa, &seqb, &score_matrix, gap_open, gap_extend)?; _global_align_core(params) }) } @@ -675,30 +704,30 @@ fn _local_global_align_core(params: AlignmentParams) -> PyResult { let mut max_row 
= 0; let mut max_col = 0; - for row in 0..params.sb_len { + for row in 0..params.sb.len() { let score = params.gap_cost(row as i32 + 1); data.prev_score[row] = score; data.hgap_pos[row] = -1; data.hgap_score[row] = score.saturating_add(params.gap_open); } - for col in 0..params.sa_len { + for col in 0..params.sa.len() { data.vgap_pos = -1; data.vgap_score = params.gap_open; let (score, _) = data.compute_and_write_cell(0, col, params.match_score(0, col)); data.update_gaps(0, col, score, ¶ms); - for row in 1..params.sb_len { + for row in 1..params.sb.len() { let match_score = data.prev_score[row - 1].saturating_add(params.match_score(row, col)); let (score, _) = data.compute_and_write_cell(row, col, match_score); data.update_gaps(row, col, score, ¶ms); } - if data.curr_score[params.sb_len - 1] >= max_score { - max_row = params.sb_len - 1; + if data.curr_score[params.sb.len() - 1] >= max_score { + max_row = params.sb.len() - 1; max_col = col; - max_score = data.curr_score[params.sb_len - 1]; + max_score = data.curr_score[params.sb.len() - 1]; } data.swap_scores(); } @@ -745,13 +774,7 @@ fn local_global_align<'py>( let score_matrix = score_matrix.as_array().into_owned(); py.detach(move || { - let params = AlignmentParams::new( - &seqa, - &seqb, - &score_matrix, - gap_open, - gap_extend, - )?; + let params = AlignmentParams::new(&seqa, &seqb, &score_matrix, gap_open, gap_extend)?; _local_global_align_core(params) }) } @@ -804,13 +827,13 @@ fn _overlap_align_core(params: AlignmentParams) -> PyResult { } }; - for row in 0..params.sb_len { + for row in 0..params.sb.len() { data.prev_score[row] = 0; data.hgap_pos[row] = -1; data.hgap_score[row] = params.gap_open; } - for col in 0..params.sa_len { + for col in 0..params.sa.len() { data.vgap_pos = -1; data.vgap_score = params.gap_open; @@ -819,18 +842,22 @@ fn _overlap_align_core(params: AlignmentParams) -> PyResult { data.write_cell(0, col, score, dir); data.update_gaps(0, col, score, ¶ms); - for row in 1..params.sb_len { + 
for row in 1..params.sb.len() { let match_score = data.prev_score[row - 1].saturating_add(params.match_score(row, col)); let (score, dir) = data.compute_cell(row, col, match_score); data.write_cell(row, col, score, dir); data.update_gaps(row, col, score, ¶ms); } - update_max_score(data.curr_score[params.sb_len - 1], params.sb_len - 1, col); + update_max_score( + data.curr_score[params.sb.len() - 1], + params.sb.len() - 1, + col, + ); data.swap_scores(); } - for row in 0..params.sb_len { - update_max_score(data.prev_score[row], row, params.sa_len - 1); + for row in 0..params.sb.len() { + update_max_score(data.prev_score[row], row, params.sa.len() - 1); } if max_score == std::i32::MIN { @@ -840,7 +867,14 @@ fn _overlap_align_core(params: AlignmentParams) -> PyResult { stats: AlignmentStats::default(), }); } - let (fragments, stats) = traceback(&data, ¶ms, max_col as usize, max_row as usize, false, false); + let (fragments, stats) = traceback( + &data, + ¶ms, + max_col as usize, + max_row as usize, + false, + false, + ); Ok(Alignment { fragments: fragments, @@ -884,13 +918,7 @@ fn overlap_align<'py>( let score_matrix = score_matrix.as_array().into_owned(); py.detach(move || { - let params = AlignmentParams::new( - &seqa, - &seqb, - &score_matrix, - gap_open, - gap_extend, - )?; + let params = AlignmentParams::new(&seqa, &seqb, &score_matrix, gap_open, gap_extend)?; _overlap_align_core(params) }) } @@ -940,26 +968,278 @@ where let pool = rayon::ThreadPoolBuilder::new() .num_threads(num_threads.unwrap_or(0)) // 0 tells rayon to use a default number of threads .build() - .map_err(|e| PyErr::new::(format!("Failed to create thread pool: {}", e)))?; + .map_err(|e| { + PyErr::new::(format!( + "Failed to create thread pool: {}", + e + )) + })?; pool.install(|| { seqbs .into_par_iter() .map(|seqb| { - let params = AlignmentParams::new( - &seqa, - &seqb, - &score_matrix, - gap_open, - gap_extend, - )?; + let params = + AlignmentParams::new(&seqa, &seqb, &score_matrix, gap_open, 
gap_extend)?; align_func(params) }) .collect() }) } +#[derive(Eq, PartialEq)] +struct Candidate { + score: i32, + sa_start: usize, + sb_start: usize, + len: usize, +} + +impl Ord for Candidate { + fn cmp(&self, other: &Self) -> Ordering { + self.score + .cmp(&other.score) + .then_with(|| self.sa_start.cmp(&other.sa_start)) + .then_with(|| self.sb_start.cmp(&other.sb_start)) + } +} + +impl PartialOrd for Candidate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn _top_k_ungapped_local_align_core( + params: UngappedAlignmentParams, + k: usize, + filter_overlap_a: bool, + filter_overlap_b: bool, +) -> PyResult> { + let sa_len = params.sa.len(); + let sb_len = params.sb.len(); + + let mut candidates: BinaryHeap = BinaryHeap::new(); + + let mut add_candidate = |score: i32, sa_start: usize, sb_start: usize, len: usize| { + if score > 0 { + candidates.push(Candidate { + score, + sa_start, + sb_start, + len, + }); + } + }; + + let mut process_diagonal = |start_row: usize, start_col: usize, max_len: usize| { + let mut curr_score = 0; + let mut segment_start_idx = 0; // index along diagonal where current positive segment started + let mut peak_score = 0; + let mut peak_idx = 0; // index along diagonal where peak occurred + + for i in 0..max_len { + let row = start_row + i; + let col = start_col + i; + let val = params.match_score(row, col); + + if curr_score == 0 && val <= 0 { + continue; + } + if curr_score == 0 { + segment_start_idx = i; + } + + curr_score += val; + + if curr_score <= 0 { + add_candidate( + peak_score, + start_col + segment_start_idx, + start_row + segment_start_idx, + peak_idx - segment_start_idx + 1, + ); + curr_score = 0; + peak_score = 0; + } else { + if curr_score > peak_score { + peak_score = curr_score; + peak_idx = i; + } + } + } + add_candidate( + peak_score, + start_col + segment_start_idx, + start_row + segment_start_idx, + peak_idx - segment_start_idx + 1, + ); + }; + + // Diagonals starting at first row 
(row=0, col=0..sa_len) + for start_col in 0..sa_len { + let max_len = std::cmp::min(sa_len - start_col, sb_len); + process_diagonal(0, start_col, max_len); + } + + // Diagonals starting at first column (row=1..sb_len, col=0) + for start_row in 1..sb_len { + let max_len = std::cmp::min(sa_len, sb_len - start_row); + process_diagonal(start_row, 0, max_len); + } + + // Select top k non-overlapping; reserve no more than the number of candidates so a very large k cannot force a huge (or overflowing) allocation + let mut alignments: Vec = Vec::with_capacity(k.min(candidates.len())); + + while alignments.len() < k { + if let Some(candidate) = candidates.pop() { + // Check overlap (half-open ranges [start, end) in each sequence) + let sa_end = candidate.sa_start + candidate.len; + let sb_end = candidate.sb_start + candidate.len; + + let overlap = alignments.iter().any(|prev| { + let p_sa_start = (prev.fragments[0].sa_start - 1) as usize; // 0-indexed + let p_sb_start = (prev.fragments[0].sb_start - 1) as usize; // 0-indexed + let p_sa_end = p_sa_start + prev.fragments[0].len as usize; + let p_sb_end = p_sb_start + prev.fragments[0].len as usize; + + // Overlap in A? + let overlaps_a = + filter_overlap_a && candidate.sa_start < p_sa_end && sa_end > p_sa_start; + // Overlap in B? 
+ let overlaps_b = + filter_overlap_b && candidate.sb_start < p_sb_end && sb_end > p_sb_start; + + overlaps_a || overlaps_b + }); + + if !overlap { + // Construct Alignment + let mut stats = AlignmentStats::default(); + for i in 0..candidate.len { + let r = candidate.sb_start + i; + let c = candidate.sa_start + i; + let val = params.match_score(r, c); // match_score(row, col): row indexes seqb, col indexes seqa + if params.sa[c] == params.sb[r] { + stats.num_exact_matches += 1; + } else if val > 0 { + stats.num_positive_mismatches += 1; + } else { + stats.num_negative_mismatches += 1; + } + } + + let frag = AlignmentFragment { + fragment_type: FragmentType::Match, + sa_start: (candidate.sa_start + 1) as i32, + sb_start: (candidate.sb_start + 1) as i32, + len: candidate.len as i32, + }; + + alignments.push(Alignment { + fragments: vec![frag], + score: candidate.score, + stats: stats, + }); + } + } else { + break; + } + } + + Ok(alignments) +} + +/// Finds the top-k non-overlapping ungapped local alignments (HSPs). +/// +/// Args: +/// seqa (bytes): The first sequence. +/// seqb (bytes): The second sequence. +/// score_matrix (numpy.ndarray): Scoring matrix. +/// k (int): Maximum number of alignments to return. +/// +/// Returns: +/// list[Alignment]: List of top-k non-overlapping alignments. 
+#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(signature = (seqa, seqb, score_matrix, k, filter_overlap_a=true, filter_overlap_b=true))] +fn top_k_ungapped_local_align<'py>( + py: Python<'py>, + seqa: &Bound<'py, PyBytes>, + seqb: &Bound<'py, PyBytes>, + score_matrix: PyReadonlyArray2, + k: usize, + filter_overlap_a: bool, + filter_overlap_b: bool, +) -> PyResult> { + let seqa = seqa.as_bytes().to_vec(); + let seqb = seqb.as_bytes().to_vec(); + let score_matrix = score_matrix.as_array().into_owned(); + + py.detach(move || { + _top_k_ungapped_local_align_core( + UngappedAlignmentParams::new(&seqa, &seqb, &score_matrix)?, + k, + filter_overlap_a, + filter_overlap_b, + ) + }) +} + +/// Finds the top-k non-overlapping ungapped local alignments (HSPs) against many sequences in parallel. +/// +/// Args: +/// seqa (bytes): The query sequence. +/// seqbs (list[bytes]): List of target sequences. +/// score_matrix (numpy.ndarray): Scoring matrix. +/// k (int): Number of alignments to return per target sequence. +/// num_threads (int, optional): Number of threads to use. Defaults to all available. +/// +/// Returns: +/// list[list[Alignment]]: List of alignment lists. 
+#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(signature = (seqa, seqbs, score_matrix, k, num_threads=None, filter_overlap_a=true, filter_overlap_b=true))] +fn top_k_ungapped_local_align_many<'py>( + py: Python<'py>, + seqa: &Bound<'py, PyBytes>, + seqbs: Vec>, + score_matrix: PyReadonlyArray2, + k: usize, + num_threads: Option, + filter_overlap_a: bool, + filter_overlap_b: bool, +) -> PyResult>> { + let seqa = seqa.as_bytes().to_vec(); + let seqbs: Vec> = seqbs.iter().map(|s| s.as_bytes().to_vec()).collect(); + let score_matrix = score_matrix.as_array().into_owned(); + py.detach(move || { + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(num_threads.unwrap_or(0)) + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to create thread pool: {}", + e + )) + })?; + + pool.install(|| { + seqbs + .into_par_iter() + .map(|seqb| { + _top_k_ungapped_local_align_core( + UngappedAlignmentParams::new(&seqa, &seqb, &score_matrix)?, + k, + filter_overlap_a, + filter_overlap_b, + ) + }) + .collect() + }) + }) +} #[pymodule] fn _seq_smith(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -971,6 +1251,8 @@ fn _seq_smith(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(global_align_many))?; m.add_wrapped(wrap_pyfunction!(local_global_align_many))?; m.add_wrapped(wrap_pyfunction!(overlap_align_many))?; + m.add_wrapped(wrap_pyfunction!(top_k_ungapped_local_align))?; + m.add_wrapped(wrap_pyfunction!(top_k_ungapped_local_align_many))?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tests/test_seq_smith.py b/tests/test_seq_smith.py index b693855..00aa62d 100644 --- a/tests/test_seq_smith.py +++ b/tests/test_seq_smith.py @@ -13,6 +13,8 @@ local_global_align, make_score_matrix, overlap_align, + top_k_ungapped_local_align, + top_k_ungapped_local_align_many, ) @@ -476,3 +478,96 @@ def test_local_global_align_overhangs() -> None: assert aln.fragments == expected_fragments assert aln.score == -2 + + +def 
test_top_k_ungapped_simple() -> None: + # The plain ACGT alphabet is sufficient for this test + alphabet = "ACGT" + seqa = encode("AAAATTTTCCCC", alphabet) + seqb = encode("AAAAGGGGCCCC", alphabet) + + # matrix: match=2, mismatch=-5 + score_matrix = make_score_matrix(alphabet, match_score=2, mismatch_score=-5) + + # AAAA matches (4*2=8). TTTT vs GGGG (-5*4 = -20). CCCC matches (8). + # Alignment 1: AAAA (score 8) + # Alignment 2: CCCC (score 8) + + alignments = top_k_ungapped_local_align(seqa, seqb, score_matrix, k=5) + + # Should get 2 alignments + assert len(alignments) == 2 + # Verify scores + assert alignments[0].score == 8 + assert alignments[1].score == 8 + + starts = sorted([(a.fragments[0].sa_start, a.fragments[0].sb_start) for a in alignments]) + assert starts[0] == (1, 1) # 1-based index in fragments + assert starts[1] == (9, 9) + + +def test_top_k_ungapped_overlapping_candidates(common_data: AlignmentData) -> None: + # Case where second best candidate overlaps best + # Sequence A: AAAAA + # Sequence B: AAAAA + # match=2, mismatch=-1 (custom matrix built below) + + seqa = encode("AAAAA", common_data.alphabet) + seqb = encode("AAAAA", common_data.alphabet) + + # build our own matrix so the expected score does not depend on common_data's matrix + score_matrix = make_score_matrix(common_data.alphabet, match_score=2, mismatch_score=-1) + + alignments = top_k_ungapped_local_align(seqa, seqb, score_matrix, k=2) + + assert len(alignments) == 1 + assert alignments[0].score == 10 # 5 * 2 + assert alignments[0].fragments[0].len == 5 + + +def test_top_k_ungapped_limit() -> None: + # A: AA..CC..GG + # B: AA..CC..GG + alphabet = "ACGT" + # Use T vs G for mismatch + seqa = encode("AATTCCTTGG", alphabet) + seqb = encode("AAGGCCGGGG", alphabet) + + score_matrix = make_score_matrix(alphabet, match_score=2, mismatch_score=-5) + + # Expected HSPs: AA (4), mismatch, CC (4), mismatch, GG (4) + + alignments = top_k_ungapped_local_align(seqa, seqb, score_matrix, k=2) + + assert len(alignments) == 
2 + assert alignments[0].score == 4 + assert alignments[1].score == 4 + + +def test_top_k_ungapped_many_simple() -> None: + # Sequence A: AAAA + # Sequence B1: AAAA (perfect) + # Sequence B2: CCCC (mismatch) + alphabet = "ACGT" + seqa = encode("AAAA", alphabet) + + seqb1 = encode("AAAA", alphabet) + seqb2 = encode("CCCC", alphabet) + + score_matrix = make_score_matrix(alphabet, match_score=2, mismatch_score=-5) + + alignments_list = top_k_ungapped_local_align_many(seqa, [seqb1, seqb2], score_matrix, k=1) + + assert len(alignments_list) == 2 + + # Check first alignment (AAAA vs AAAA) + # Should have score 8 + assert len(alignments_list[0]) == 1 + assert alignments_list[0][0].score == 8 + + # Check second alignment (AAAA vs CCCC) + # Should have no positive candidates if mismatch penalty is high enough? + # -5 * 4 = -20. + # So should be empty if score <= 0. + # Our implementation returns empty if no positive peaks. + assert len(alignments_list[1]) == 0