diff --git a/Makefile b/Makefile index f71e41ac0f..6e168e261d 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,8 @@ include/sourmash.h: src/core/src/lib.rs \ src/core/src/ffi/minhash.rs \ src/core/src/ffi/signature.rs \ src/core/src/ffi/nodegraph.rs \ + src/core/src/ffi/index/mod.rs \ + src/core/src/ffi/index/revindex.rs \ src/core/src/errors.rs cd src/core && \ RUSTC_BOOTSTRAP=1 cbindgen -c cbindgen.toml . -o ../../$@ diff --git a/include/sourmash.h b/include/sourmash.h index 788d483605..de422efcf5 100644 --- a/include/sourmash.h +++ b/include/sourmash.h @@ -52,6 +52,10 @@ typedef struct SourmashKmerMinHash SourmashKmerMinHash; typedef struct SourmashNodegraph SourmashNodegraph; +typedef struct SourmashRevIndex SourmashRevIndex; + +typedef struct SourmashSearchResult SourmashSearchResult; + typedef struct SourmashSignature SourmashSignature; /** @@ -302,6 +306,51 @@ SourmashNodegraph *nodegraph_with_tables(uintptr_t ksize, uintptr_t starting_size, uintptr_t n_tables); +void revindex_free(SourmashRevIndex *ptr); + +const SourmashSearchResult *const *revindex_gather(const SourmashRevIndex *ptr, + const SourmashSignature *sig_ptr, + double threshold, + bool _do_containment, + bool _ignore_abundance, + uintptr_t *size); + +uint64_t revindex_len(const SourmashRevIndex *ptr); + +SourmashRevIndex *revindex_new_with_paths(const SourmashStr *const *search_sigs_ptr, + uintptr_t insigs, + const SourmashKmerMinHash *template_ptr, + uintptr_t threshold, + const SourmashKmerMinHash *const *queries_ptr, + uintptr_t inqueries, + bool keep_sigs); + +SourmashRevIndex *revindex_new_with_sigs(const SourmashSignature *const *search_sigs_ptr, + uintptr_t insigs, + const SourmashKmerMinHash *template_ptr, + uintptr_t threshold, + const SourmashKmerMinHash *const *queries_ptr, + uintptr_t inqueries); + +uint64_t revindex_scaled(const SourmashRevIndex *ptr); + +const SourmashSearchResult *const *revindex_search(const SourmashRevIndex *ptr, + const SourmashSignature *sig_ptr, + double threshold, + bool do_containment, + bool _ignore_abundance, + uintptr_t *size); + +SourmashSignature **revindex_signatures(const SourmashRevIndex *ptr, uintptr_t *size); + +SourmashStr searchresult_filename(const SourmashSearchResult *ptr); + +void searchresult_free(SourmashSearchResult *ptr); + +double searchresult_score(const SourmashSearchResult *ptr); + +SourmashSignature *searchresult_signature(const SourmashSearchResult *ptr); + void signature_add_protein(SourmashSignature *ptr, const char *sequence); void signature_add_sequence(SourmashSignature *ptr, const char *sequence, bool force); diff --git a/setup.py b/setup.py index f416ff3336..08d7dbc8f8 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,10 @@ def build_native(spec): - cmd = ["cargo", "build", "--manifest-path", "src/core/Cargo.toml", "--lib"] + cmd = ["cargo", "build", + "--manifest-path", "src/core/Cargo.toml", + # "--features", "parallel", + "--lib"] target = "debug" if not DEBUG_BUILD: diff --git a/src/core/CHANGELOG.md b/src/core/CHANGELOG.md index 05eff4f83b..3915a62086 100644 --- a/src/core/CHANGELOG.md +++ b/src/core/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +Added: + +- An inverted index, codename Greyhound (#1238) + ## [0.11.0] - 2021-07-07 Added: diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 359bc26d4b..0efe6dc1c2 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -26,6 +26,7 @@ az = "1.0.0" bytecount = "0.6.0" byteorder = "1.4.3" cfg-if = "1.0" +counter = "0.5.2" finch = { version = "0.4.1", optional = true } fixedbitset = "0.4.0" getset = "0.1.1" @@ -42,6 +43,10 @@ serde_json = "1.0.53" primal-check = "0.3.1" thiserror = "1.0" typed-builder = "0.9.0" +twox-hash = "1.6.0" +vec-collections = "0.3.4" +piz = "0.4.0" +memmap2 = "0.5.0" [dev-dependencies] assert_matches = "1.3.0" diff --git a/src/core/cbindgen.toml b/src/core/cbindgen.toml index faeef5c41f..cd6cd781c2 100644 --- a/src/core/cbindgen.toml +++ b/src/core/cbindgen.toml @@ -8,6 +8,7 @@ clean = true [parse.expand] crates = ["sourmash"] +features = [] [enum] rename_variants = "QualifiedScreamingSnakeCase" diff --git a/src/core/src/encodings.rs b/src/core/src/encodings.rs index 5d1cc94fb2..9749e886e6 100644 --- a/src/core/src/encodings.rs +++ b/src/core/src/encodings.rs @@ -1,12 +1,26 @@ +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::convert::TryFrom; +use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; use std::iter::Iterator; use std::str; +use nohash_hasher::BuildNoHashHasher; use once_cell::sync::Lazy; use crate::Error; +// To consider there: use a slab allocator for IdxTracker +// https://twitter.com/tomaka17/status/1391052081272967170 +// Pro-tip: you might be able to save a lot of hashmap lookups +// if you replace a `HashMap` with a `HashMap` +// and a `Slab`. This might be very useful if K is something +// heavy such as a `String`. +pub type Color = u64; +pub type Idx = u64; +type IdxTracker = (vec_collections::VecSet<[Idx; 4]>, u64); +type ColorToIdx = HashMap>; + #[allow(non_camel_case_types)] #[derive(Debug, Clone, Copy, PartialEq)] #[repr(u32)] @@ -357,3 +371,202 @@ pub const VALID: [bool; 256] = { lookup[b'T' as usize] = true; lookup }; + +#[derive(Serialize, Deserialize, Default)] +pub struct Colors { + colors: ColorToIdx, +} + +impl Colors { + pub fn new() -> Colors { + Default::default() + } + + /// Given a color and a new idx, return an updated color + /// + /// This might create a new one, or find an already existing color + /// that contains the new_idx + /// + /// Future optimization: store a count for each color, so we can track + /// if there are extra colors that can be removed at the end. + /// (the count is decreased whenever a new color has to be created) + pub fn update<'a, I: IntoIterator>( + &mut self, + current_color: Option, + new_idxs: I, + ) -> Result { + if let Some(color) = current_color { + if let Some(idxs) = self.colors.get_mut(&color) { + let idx_to_add: Vec<_> = new_idxs + .into_iter() + .filter(|new_idx| !idxs.0.contains(new_idx)) + .collect(); + + if idx_to_add.is_empty() { + // Easy case, it already has all the new_idxs, so just return this color + idxs.1 += 1; + Ok(color) + } else { + // We need to either create a new color, + // or find an existing color that have the same idxs + + let mut idxs = idxs.clone(); + idxs.0.extend(idx_to_add.into_iter().cloned()); + let new_color = Colors::compute_color(&idxs); + + if new_color != color { + self.colors.get_mut(&color).unwrap().1 -= 1; + if self.colors[&color].1 == 0 { + self.colors.remove(&color); + }; + }; + + self.colors + .entry(new_color) + .and_modify(|old_idxs| { + assert_eq!(old_idxs.0, idxs.0); + old_idxs.1 += 1; + }) + .or_insert_with(|| (idxs.0, 1)); + Ok(new_color) + } + } else { + unimplemented!("throw error, current_color must exist in order to be updated. current_color: {:?}, colors: {:#?}", current_color, &self.colors); + } + } else { + let mut idxs = IdxTracker::default(); + idxs.0.extend(new_idxs.into_iter().cloned()); + idxs.1 = 1; + let new_color = Colors::compute_color(&idxs); + self.colors + .entry(new_color) + .and_modify(|old_idxs| { + assert_eq!(old_idxs.0, idxs.0); + old_idxs.1 += 1; + }) + .or_insert_with(|| (idxs.0, 1)); + Ok(new_color) + } + } + + fn compute_color(idxs: &IdxTracker) -> Color { + let s = BuildHasherDefault::::default(); + let mut hasher = s.build_hasher(); + idxs.0.hash(&mut hasher); + hasher.finish() + } + + pub fn len(&self) -> usize { + self.colors.len() + } + + pub fn is_empty(&self) -> bool { + self.colors.is_empty() + } + + pub fn contains(&self, color: Color, idx: Idx) -> bool { + if let Some(idxs) = self.colors.get(&color) { + idxs.0.contains(&idx) + } else { + false + } + } + + pub fn indices(&self, color: &Color) -> Indices { + // TODO: what if color is not present? + Indices { + iter: self.colors.get(color).unwrap().0.iter(), + } + } + + pub fn retain(&mut self, f: F) + where + F: FnMut(&Color, &mut IdxTracker) -> bool, + { + self.colors.retain(f) + } +} + +pub struct Indices<'a> { + iter: vec_collections::VecSetIter>, +} + +impl<'a> Iterator for Indices<'a> { + type Item = &'a Idx; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn colors_update() { + let mut colors = Colors::new(); + + let color = colors.update(None, &[1_u64]).unwrap(); + assert_eq!(colors.len(), 1); + + dbg!("update"); + let new_color = colors.update(Some(color), &[1_u64]).unwrap(); + assert_eq!(colors.len(), 1); + assert_eq!(color, new_color); + + dbg!("upgrade"); + let new_color = colors.update(Some(color), &[2_u64]).unwrap(); + assert_eq!(colors.len(), 2); + assert_ne!(color, new_color); + } + + #[test] + fn colors_retain() { + let mut colors = Colors::new(); + + let color1 = colors.update(None, &[1_u64]).unwrap(); + assert_eq!(colors.len(), 1); + // used_colors: + // color1: 1 + + dbg!("update"); + let same_color = colors.update(Some(color1), &[1_u64]).unwrap(); + assert_eq!(colors.len(), 1); + assert_eq!(color1, same_color); + // used_colors: + // color1: 2 + + dbg!("upgrade"); + let color2 = colors.update(Some(color1), &[2_u64]).unwrap(); + assert_eq!(colors.len(), 2); + assert_ne!(color1, color2); + // used_colors: + // color1: 1 + // color2: 1 + + dbg!("update"); + let same_color = colors.update(Some(color2), &[2_u64]).unwrap(); + assert_eq!(colors.len(), 2); + assert_eq!(color2, same_color); + // used_colors: + // color1: 1 + // color1: 2 + + dbg!("upgrade"); + let color3 = colors.update(Some(color1), &[3_u64]).unwrap(); + assert_ne!(color1, color3); + assert_ne!(color2, color3); + // used_colors: + // color1: 0 + // color2: 2 + // color3: 1 + + // This is the pre color-count tracker, where it is needed + // to call retain to maintain colors + //assert_eq!(colors.len(), 3); + //colors.retain(|c, _| [color2, color3].contains(c)); + + assert_eq!(colors.len(), 2); + } +} diff --git a/src/core/src/ffi/index/mod.rs b/src/core/src/ffi/index/mod.rs new file mode 100644 index 0000000000..932a97b222 --- /dev/null +++ b/src/core/src/ffi/index/mod.rs @@ -0,0 +1,37 @@ +pub mod revindex; + +use crate::signature::Signature; + +use crate::ffi::signature::SourmashSignature; +use crate::ffi::utils::{ForeignObject, SourmashStr}; + +pub struct SourmashSearchResult; + +impl ForeignObject for SourmashSearchResult { + type RustObject = (f64, Signature, String); +} + +#[no_mangle] +pub unsafe extern "C" fn searchresult_free(ptr: *mut SourmashSearchResult) { + SourmashSearchResult::drop(ptr); +} + +#[no_mangle] +pub unsafe extern "C" fn searchresult_score(ptr: *const SourmashSearchResult) -> f64 { + let result = SourmashSearchResult::as_rust(ptr); + result.0 +} + +#[no_mangle] +pub unsafe extern "C" fn searchresult_filename(ptr: *const SourmashSearchResult) -> SourmashStr { + let result = SourmashSearchResult::as_rust(ptr); + (result.2).clone().into() +} + +#[no_mangle] +pub unsafe extern "C" fn searchresult_signature( + ptr: *const SourmashSearchResult, +) -> *mut SourmashSignature { + let result = SourmashSearchResult::as_rust(ptr); + SourmashSignature::from_rust((result.1).clone()) +} diff --git a/src/core/src/ffi/index/revindex.rs b/src/core/src/ffi/index/revindex.rs new file mode 100644 index 0000000000..3597121bce --- /dev/null +++ b/src/core/src/ffi/index/revindex.rs @@ -0,0 +1,250 @@ +use std::path::PathBuf; +use std::slice; + +use crate::index::revindex::RevIndex; +use crate::index::Index; +use crate::signature::{Signature, SigsTrait}; +use crate::sketch::minhash::KmerMinHash; +use crate::sketch::Sketch; + +use crate::ffi::index::SourmashSearchResult; +use crate::ffi::minhash::SourmashKmerMinHash; +use crate::ffi::signature::SourmashSignature; +use crate::ffi::utils::{ForeignObject, SourmashStr}; + +pub struct SourmashRevIndex; + +impl ForeignObject for SourmashRevIndex { + type RustObject = RevIndex; +} + +ffi_fn! { +unsafe fn revindex_new_with_paths( + search_sigs_ptr: *const *const SourmashStr, + insigs: usize, + template_ptr: *const SourmashKmerMinHash, + threshold: usize, + queries_ptr: *const *const SourmashKmerMinHash, + inqueries: usize, + keep_sigs: bool, +) -> Result<*mut SourmashRevIndex> { + let search_sigs: Vec = { + assert!(!search_sigs_ptr.is_null()); + slice::from_raw_parts(search_sigs_ptr, insigs) + .iter() + .map(|path| { + let mut new_path = PathBuf::new(); + new_path.push(SourmashStr::as_rust(*path).as_str()); + new_path + }) + .collect() + }; + + let template = { + assert!(!template_ptr.is_null()); + //TODO: avoid clone here + Sketch::MinHash(SourmashKmerMinHash::as_rust(template_ptr).clone()) + }; + + let queries_vec: Vec; + let queries: Option<&[KmerMinHash]> = if queries_ptr.is_null() { + None + } else { + queries_vec = slice::from_raw_parts(queries_ptr, inqueries) + .iter() + .map(|mh_ptr| + // TODO: avoid this clone + SourmashKmerMinHash::as_rust(*mh_ptr).clone()) + .collect(); + Some(queries_vec.as_ref()) + }; + let revindex = RevIndex::new( + search_sigs.as_ref(), + &template, + threshold, + queries, + keep_sigs, + ); + Ok(SourmashRevIndex::from_rust(revindex)) +} +} + +ffi_fn! { +unsafe fn revindex_new_with_sigs( + search_sigs_ptr: *const *const SourmashSignature, + insigs: usize, + template_ptr: *const SourmashKmerMinHash, + threshold: usize, + queries_ptr: *const *const SourmashKmerMinHash, + inqueries: usize, +) -> Result<*mut SourmashRevIndex> { + let search_sigs: Vec = { + assert!(!search_sigs_ptr.is_null()); + slice::from_raw_parts(search_sigs_ptr, insigs) + .iter() + .map(|sig| SourmashSignature::as_rust(*sig)) + .cloned() + .collect() + }; + + let template = { + assert!(!template_ptr.is_null()); + //TODO: avoid clone here + Sketch::MinHash(SourmashKmerMinHash::as_rust(template_ptr).clone()) + }; + + let queries_vec: Vec; + let queries: Option<&[KmerMinHash]> = if queries_ptr.is_null() { + None + } else { + queries_vec = slice::from_raw_parts(queries_ptr, inqueries) + .iter() + .map(|mh_ptr| + // TODO: avoid this clone + SourmashKmerMinHash::as_rust(*mh_ptr).clone()) + .collect(); + Some(queries_vec.as_ref()) + }; + let revindex = RevIndex::new_with_sigs(search_sigs, &template, threshold, queries); + Ok(SourmashRevIndex::from_rust(revindex)) +} +} + +#[no_mangle] +pub unsafe extern "C" fn revindex_free(ptr: *mut SourmashRevIndex) { + SourmashRevIndex::drop(ptr); +} + +ffi_fn! { +unsafe fn revindex_search( + ptr: *const SourmashRevIndex, + sig_ptr: *const SourmashSignature, + threshold: f64, + do_containment: bool, + _ignore_abundance: bool, + size: *mut usize, +) -> Result<*const *const SourmashSearchResult> { + let revindex = SourmashRevIndex::as_rust(ptr); + let sig = SourmashSignature::as_rust(sig_ptr); + + if sig.signatures.is_empty() { + *size = 0; + return Ok(std::ptr::null::<*const SourmashSearchResult>()); + } + + let mh = if let Sketch::MinHash(mh) = &sig.signatures[0] { + mh + } else { + // TODO: what if it is not a mh? + unimplemented!() + }; + + let results: Vec<(f64, Signature, String)> = revindex + .find_signatures(mh, threshold, do_containment, true)? + .into_iter() + .collect(); + + // FIXME: use the ForeignObject trait, maybe define new method there... + let ptr_sigs: Vec<*const SourmashSearchResult> = results + .into_iter() + .map(|x| Box::into_raw(Box::new(x)) as *const SourmashSearchResult) + .collect(); + + let b = ptr_sigs.into_boxed_slice(); + *size = b.len(); + + Ok(Box::into_raw(b) as *const *const SourmashSearchResult) +} +} + +ffi_fn! { +unsafe fn revindex_gather( + ptr: *const SourmashRevIndex, + sig_ptr: *const SourmashSignature, + threshold: f64, + _do_containment: bool, + _ignore_abundance: bool, + size: *mut usize, +) -> Result<*const *const SourmashSearchResult> { + let revindex = SourmashRevIndex::as_rust(ptr); + let sig = SourmashSignature::as_rust(sig_ptr); + + if sig.signatures.is_empty() { + *size = 0; + return Ok(std::ptr::null::<*const SourmashSearchResult>()); + } + + let mh = if let Sketch::MinHash(mh) = &sig.signatures[0] { + mh + } else { + // TODO: what if it is not a mh? + unimplemented!() + }; + + // TODO: proper threshold calculation + let threshold: usize = (threshold * (mh.size() as f64)) as _; + + let counter = revindex.counter_for_query(mh); + dbg!(&counter); + + let results: Vec<(f64, Signature, String)> = revindex + .gather(counter, threshold, mh) + .unwrap() // TODO: proper error handling + .into_iter() + .map(|r| { + let filename = r.filename().to_owned(); + let sig = r.get_match(); + (r.f_match(), sig, filename) + }) + .collect(); + + // FIXME: use the ForeignObject trait, maybe define new method there... + let ptr_sigs: Vec<*const SourmashSearchResult> = results + .into_iter() + .map(|x| Box::into_raw(Box::new(x)) as *const SourmashSearchResult) + .collect(); + + let b = ptr_sigs.into_boxed_slice(); + *size = b.len(); + + Ok(Box::into_raw(b) as *const *const SourmashSearchResult) +} +} + +#[no_mangle] +pub unsafe extern "C" fn revindex_scaled(ptr: *const SourmashRevIndex) -> u64 { + let revindex = SourmashRevIndex::as_rust(ptr); + if let Sketch::MinHash(mh) = revindex.template() { + mh.scaled() + } else { + unimplemented!() + } +} + +#[no_mangle] +pub unsafe extern "C" fn revindex_len(ptr: *const SourmashRevIndex) -> u64 { + let revindex = SourmashRevIndex::as_rust(ptr); + revindex.len() as u64 +} + +ffi_fn! { +unsafe fn revindex_signatures( + ptr: *const SourmashRevIndex, + size: *mut usize, +) -> Result<*mut *mut SourmashSignature> { + let revindex = SourmashRevIndex::as_rust(ptr); + + let sigs = revindex.signatures(); + + // FIXME: use the ForeignObject trait, maybe define new method there... + let ptr_sigs: Vec<*mut SourmashSignature> = sigs + .into_iter() + .map(|x| Box::into_raw(Box::new(x)) as *mut SourmashSignature) + .collect(); + + let b = ptr_sigs.into_boxed_slice(); + *size = b.len(); + + Ok(Box::into_raw(b) as *mut *mut SourmashSignature) +} +} diff --git a/src/core/src/ffi/mod.rs b/src/core/src/ffi/mod.rs index bfd9b46bd7..e9f276d5e2 100644 --- a/src/core/src/ffi/mod.rs +++ b/src/core/src/ffi/mod.rs @@ -8,6 +8,7 @@ pub mod utils; pub mod cmd; pub mod hyperloglog; +pub mod index; pub mod minhash; pub mod nodegraph; pub mod signature; diff --git a/src/core/src/ffi/utils.rs b/src/core/src/ffi/utils.rs index 04652cbeef..01f2221690 100644 --- a/src/core/src/ffi/utils.rs +++ b/src/core/src/ffi/utils.rs @@ -314,3 +314,7 @@ pub unsafe extern "C" fn sourmash_str_free(s: *mut SourmashStr) { (*s).free() } } + +impl ForeignObject for SourmashStr { + type RustObject = SourmashStr; +} diff --git a/src/core/src/index/linear.rs b/src/core/src/index/linear.rs index 3208acdc05..c6e14b64f6 100644 --- a/src/core/src/index/linear.rs +++ b/src/core/src/index/linear.rs @@ -2,20 +2,19 @@ use std::fs::File; use std::io::{BufReader, Read}; use std::path::Path; use std::path::PathBuf; -use std::rc::Rc; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use crate::index::{Comparable, DatasetInfo, Index, SigStore}; use crate::prelude::*; -use crate::storage::{FSStorage, Storage, StorageInfo}; +use crate::storage::{FSStorage, InnerStorage, Storage, StorageInfo}; use crate::Error; #[derive(TypedBuilder)] pub struct LinearIndex { #[builder(default)] - storage: Option>, + storage: Option, #[builder(default)] datasets: Vec>, @@ -85,7 +84,7 @@ where pub fn save_file>( &mut self, path: P, - storage: Option>, + storage: Option, ) -> Result<(), Error> { let ref_path = path.as_ref(); let mut basename = ref_path.file_name().unwrap().to_str().unwrap().to_owned(); @@ -98,7 +97,7 @@ where Some(s) => s, None => { let subdir = format!(".linear.{}", basename); - Rc::new(FSStorage::new(location.to_str().unwrap(), &subdir)) + InnerStorage::new(FSStorage::new(location.to_str().unwrap(), &subdir)) } }; @@ -119,7 +118,7 @@ where let _: &L = (*l).data().unwrap(); // set storage to new one - l.storage = Some(Rc::clone(&storage)); + l.storage = Some(storage.clone()); let filename = (*l).save(&l.filename).unwrap(); @@ -164,23 +163,23 @@ where // TODO: support other storages let mut st: FSStorage = (&linear.storage.args).into(); st.set_base(path.as_ref().to_str().unwrap()); - let storage: Rc = Rc::new(st); + let storage = InnerStorage::new(st); Ok(LinearIndex { - storage: Some(Rc::clone(&storage)), + storage: Some(storage.clone()), datasets: linear .leaves .into_iter() .map(|l| { let mut v: SigStore = l.into(); - v.storage = Some(Rc::clone(&storage)); + v.storage = Some(storage.clone()); v }) .collect(), }) } - pub fn storage(&self) -> Option> { + pub fn storage(&self) -> Option { self.storage.clone() } } diff --git a/src/core/src/index/mod.rs b/src/core/src/index/mod.rs index 9dda675edf..9ed78af93a 100644 --- a/src/core/src/index/mod.rs +++ b/src/core/src/index/mod.rs @@ -5,13 +5,13 @@ pub mod bigsi; pub mod linear; +pub mod revindex; pub mod sbt; pub mod search; use std::ops::Deref; use std::path::Path; -use std::rc::Rc; use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; @@ -24,7 +24,7 @@ use crate::prelude::*; use crate::signature::SigsTrait; use crate::sketch::nodegraph::Nodegraph; use crate::sketch::Sketch; -use crate::storage::Storage; +use crate::storage::{InnerStorage, Storage}; use crate::Error; pub type MHBT = SBT, Signature>; @@ -98,6 +98,14 @@ pub trait Index<'a> { fn signature_refs(&self) -> Vec<&Self::Item>; + fn len(&self) -> usize { + self.signature_refs().len() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + /* fn iter_signatures(&self) -> Self::SignatureIterator; */ @@ -134,7 +142,7 @@ pub struct SigStore { #[builder(setter(into))] metadata: String, - storage: Option>, + storage: Option, #[builder(setter(into), default)] data: OnceCell, diff --git a/src/core/src/index/revindex.rs b/src/core/src/index/revindex.rs new file mode 100644 index 0000000000..c39fe79482 --- /dev/null +++ b/src/core/src/index/revindex.rs @@ -0,0 +1,699 @@ +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use getset::{CopyGetters, Getters, Setters}; +use log::{debug, info}; +use nohash_hasher::BuildNoHashHasher; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +use crate::encodings::{Color, Colors, Idx}; +use crate::index::Index; +use crate::signature::{Signature, SigsTrait}; +use crate::sketch::minhash::KmerMinHash; +use crate::sketch::Sketch; +use crate::Error; +use crate::HashIntoType; + +type SigCounter = counter::Counter; + +#[derive(Serialize, Deserialize)] +struct HashToColor(HashMap>); + +impl HashToColor { + fn new() -> Self { + HashToColor(HashMap::< + HashIntoType, + Color, + BuildNoHashHasher, + >::with_hasher(BuildNoHashHasher::default())) + } + + fn get(&self, hash: &HashIntoType) -> Option<&Color> { + self.0.get(hash) + } + + fn retain(&mut self, hashes: &HashSet) { + self.0.retain(|hash, _| hashes.contains(hash)) + } + + fn len(&self) -> usize { + self.0.len() + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn add_to(&mut self, colors: &mut Colors, dataset_id: usize, matched_hashes: Vec) { + let mut color = None; + + matched_hashes.into_iter().for_each(|hash| { + color = Some(colors.update(color, &[dataset_id as Idx]).unwrap()); + self.0.insert(hash, color.unwrap()); + }); + } + + fn reduce_hashes_colors( + a: (HashToColor, Colors), + b: (HashToColor, Colors), + ) -> (HashToColor, Colors) { + let ((small_hashes, small_colors), (mut large_hashes, mut large_colors)) = + if a.0.len() > b.0.len() { + (b, a) + } else { + (a, b) + }; + + small_hashes.0.into_iter().for_each(|(hash, color)| { + large_hashes + .0 + .entry(hash) + .and_modify(|entry| { + // Hash is already present. + // Update the current color by adding the indices from + // small_colors. + let ids = small_colors.indices(&color); + let new_color = large_colors.update(Some(*entry), ids).unwrap(); + *entry = new_color; + }) + .or_insert_with(|| { + // In this case, the hash was not present yet. + // we need to create the same color from small_colors + // into large_colors. + let ids = small_colors.indices(&color); + let new_color = large_colors.update(None, ids).unwrap(); + assert_eq!(new_color, color); + new_color + }); + }); + + (large_hashes, large_colors) + } +} + +// Use rkyv for serialization? +// https://davidkoloski.me/rkyv/ +#[derive(Serialize, Deserialize)] +pub struct RevIndex { + hash_to_color: HashToColor, + + sig_files: Vec, + + #[serde(skip)] + ref_sigs: Option>, + + template: Sketch, + colors: Colors, + //#[serde(skip)] + //storage: Option, +} + +impl RevIndex { + pub fn load>( + index_path: P, + queries: Option<&[KmerMinHash]>, + ) -> Result> { + let (rdr, _) = niffler::from_path(index_path)?; + let revindex = if let Some(qs) = queries { + // TODO: avoid loading full revindex if query != None + /* + struct PartialRevIndex { + hashes_to_keep: Option>, + marker: PhantomData T>, + } + + impl PartialRevIndex { + pub fn new(hashes_to_keep: HashSet) -> Self { + PartialRevIndex { + hashes_to_keep: Some(hashes_to_keep), + marker: PhantomData, + } + } + } + */ + + let mut hashes: HashSet = HashSet::new(); + for q in qs { + hashes.extend(q.iter_mins()); + } + + //let mut revindex: RevIndex = PartialRevIndex::new(hashes).deserialize(&rdr).unwrap(); + + let mut revindex: RevIndex = serde_json::from_reader(rdr)?; + revindex.hash_to_color.retain(&hashes); + revindex + } else { + // Load the full revindex + serde_json::from_reader(rdr)? + }; + + Ok(revindex) + } + + pub fn new( + search_sigs: &[PathBuf], + template: &Sketch, + threshold: usize, + queries: Option<&[KmerMinHash]>, + keep_sigs: bool, + ) -> RevIndex { + // If threshold is zero, let's merge all queries and save time later + let merged_query = queries.and_then(|qs| Self::merge_queries(qs, threshold)); + + let processed_sigs = AtomicUsize::new(0); + + #[cfg(feature = "parallel")] + let sig_iter = search_sigs.par_iter(); + + #[cfg(not(feature = "parallel"))] + let sig_iter = search_sigs.iter(); + + let filtered_sigs = sig_iter.enumerate().filter_map(|(dataset_id, filename)| { + let i = processed_sigs.fetch_add(1, Ordering::SeqCst); + if i % 1000 == 0 { + info!("Processed {} reference sigs", i); + } + + let search_sig = Signature::from_path(&filename) + .unwrap_or_else(|_| panic!("Error processing {:?}", filename)) + .swap_remove(0); + + RevIndex::map_hashes_colors( + dataset_id, + &search_sig, + queries, + &merged_query, + threshold, + template, + ) + }); + + #[cfg(feature = "parallel")] + let (hash_to_color, colors) = filtered_sigs.reduce( + || (HashToColor::new(), Colors::default()), + HashToColor::reduce_hashes_colors, + ); + + #[cfg(not(feature = "parallel"))] + let (hash_to_color, colors) = filtered_sigs.fold( + (HashToColor::new(), Colors::default()), + HashToColor::reduce_hashes_colors, + ); + + // TODO: build this together with hash_to_idx? + let ref_sigs = if keep_sigs { + #[cfg(feature = "parallel")] + let sigs_iter = search_sigs.par_iter(); + + #[cfg(not(feature = "parallel"))] + let sigs_iter = search_sigs.iter(); + + Some( + sigs_iter + .map(|ref_path| { + Signature::from_path(&ref_path) + .unwrap_or_else(|_| panic!("Error processing {:?}", ref_path)) + .swap_remove(0) + }) + .collect(), + ) + } else { + None + }; + + RevIndex { + hash_to_color, + sig_files: search_sigs.into(), + ref_sigs, + template: template.clone(), + colors, + // storage: Some(InnerStorage::new(MemStorage::default())), + } + } + + fn merge_queries(qs: &[KmerMinHash], threshold: usize) -> Option { + if threshold == 0 { + let mut merged = qs[0].clone(); + for query in &qs[1..] { + merged.merge(query).unwrap(); + } + Some(merged) + } else { + None + } + } + + pub fn new_with_sigs( + search_sigs: Vec, + template: &Sketch, + threshold: usize, + queries: Option<&[KmerMinHash]>, + ) -> RevIndex { + // If threshold is zero, let's merge all queries and save time later + let merged_query = queries.and_then(|qs| Self::merge_queries(qs, threshold)); + + let processed_sigs = AtomicUsize::new(0); + + #[cfg(feature = "parallel")] + let sigs_iter = search_sigs.par_iter(); + #[cfg(not(feature = "parallel"))] + let sigs_iter = search_sigs.iter(); + + let filtered_sigs = sigs_iter.enumerate().filter_map(|(dataset_id, sig)| { + let i = processed_sigs.fetch_add(1, Ordering::SeqCst); + if i % 1000 == 0 { + info!("Processed {} reference sigs", i); + } + + RevIndex::map_hashes_colors( + dataset_id, + sig, + queries, + &merged_query, + threshold, + template, + ) + }); + + #[cfg(feature = "parallel")] + let (hash_to_color, colors) = filtered_sigs.reduce( + || (HashToColor::new(), Colors::default()), + HashToColor::reduce_hashes_colors, + ); + + #[cfg(not(feature = "parallel"))] + let (hash_to_color, colors) = filtered_sigs.fold( + (HashToColor::new(), Colors::default()), + HashToColor::reduce_hashes_colors, + ); + + RevIndex { + hash_to_color, + sig_files: vec![], + ref_sigs: search_sigs.into(), + template: template.clone(), + colors, + //storage: None, + } + } + + fn map_hashes_colors( + dataset_id: usize, + search_sig: &Signature, + queries: Option<&[KmerMinHash]>, + merged_query: &Option, + threshold: usize, + template: &Sketch, + ) -> Option<(HashToColor, Colors)> { + let mut search_mh = None; + if let Some(Sketch::MinHash(mh)) = search_sig.select_sketch(template) { + search_mh = Some(mh); + } + + let search_mh = search_mh.expect("Couldn't find a compatible MinHash"); + let mut hash_to_color = HashToColor::new(); + let mut colors = Colors::default(); + + if let Some(qs) = queries { + if let Some(ref merged) = merged_query { + let (matched_hashes, intersection) = merged.intersection(search_mh).unwrap(); + if !matched_hashes.is_empty() || intersection > threshold as u64 { + hash_to_color.add_to(&mut colors, dataset_id, matched_hashes); + } + } else { + for query in qs { + let (matched_hashes, intersection) = query.intersection(search_mh).unwrap(); + if !matched_hashes.is_empty() || intersection > threshold as u64 { + hash_to_color.add_to(&mut colors, dataset_id, matched_hashes); + } + } + } + } else { + let matched = search_mh.mins(); + let size = matched.len() as u64; + if !matched.is_empty() || size > threshold as u64 { + hash_to_color.add_to(&mut colors, dataset_id, matched); + } + }; + + if hash_to_color.is_empty() { + None + } else { + Some((hash_to_color, colors)) + } + } + + pub fn search( + &self, + counter: SigCounter, + similarity: bool, + threshold: usize, + ) -> Result, Box> { + let mut matches = vec![]; + if similarity { + unimplemented!("TODO: threshold correction") + } + + for (dataset_id, size) in counter.most_common() { + if size >= threshold { + matches.push(self.sig_files[dataset_id as usize].to_str().unwrap().into()); + } else { + break; + }; + } + Ok(matches) + } + + pub fn gather( + &self, + mut counter: SigCounter, + threshold: usize, + query: &KmerMinHash, + ) -> Result, Box> { + let mut match_size = usize::max_value(); + let mut matches = vec![]; + + while match_size > threshold && !counter.is_empty() { + let (dataset_id, size) = counter.most_common()[0]; + match_size = if size >= threshold { size } else { break }; + + let p; + let match_path = if self.sig_files.is_empty() { + p = PathBuf::new(); // TODO: Fix somehow? + &p + } else { + &self.sig_files[dataset_id as usize] + }; + + let ref_match; + let match_sig = if let Some(refsigs) = &self.ref_sigs { + &refsigs[dataset_id as usize] + } else { + // TODO: remove swap_remove + ref_match = Signature::from_path(&match_path)?.swap_remove(0); + &ref_match + }; + + let mut match_mh = None; + if let Some(Sketch::MinHash(mh)) = match_sig.select_sketch(&self.template) { + match_mh = Some(mh); + } + let match_mh = match_mh.expect("Couldn't find a compatible MinHash"); + + // Calculate stats + let f_orig_query = match_size as f64 / query.size() as f64; + let f_match = match_size as f64 / match_mh.size() as f64; + let filename = match_path.to_str().unwrap().into(); + let name = match_sig.name(); + let unique_intersect_bp = match_mh.scaled() as usize * match_size; + let gather_result_rank = matches.len(); + + let (intersect_orig, _) = match_mh.intersection_size(query)?; + let intersect_bp = (match_mh.scaled() as u64 * intersect_orig) as usize; + + let f_unique_to_query = intersect_orig as f64 / query.size() as f64; + let match_ = match_sig.clone(); + + // TODO: all of these + let f_unique_weighted = 0.; + let average_abund = 0; + let median_abund = 0; + let std_abund = 0; + let md5 = "".into(); + let f_match_orig = 0.; + let remaining_bp = 0; + + let result = GatherResult { + intersect_bp, + f_orig_query, + f_match, + f_unique_to_query, + f_unique_weighted, + average_abund, + median_abund, + std_abund, + filename, + name, + md5, + match_, + f_match_orig, + unique_intersect_bp, + gather_result_rank, + remaining_bp, + }; + matches.push(result); + + // Prepare counter for finding the next match by decrementing + // all hashes found in the current match in other datasets + for hash in match_mh.iter_mins() { + if let Some(color) = self.hash_to_color.get(hash) { + for dataset in self.colors.indices(color) { + counter.entry(*dataset).and_modify(|e| { + if *e > 0 { + *e -= 1 + } + }); + } + } + } + counter.remove(&dataset_id); + } + Ok(matches) + } + + pub fn counter_for_query(&self, query: &KmerMinHash) -> SigCounter { + query + .iter_mins() + .filter_map(|hash| self.hash_to_color.get(hash)) + .flat_map(|color| self.colors.indices(color)) + .cloned() + .collect() + } + + pub fn template(&self) -> Sketch { + self.template.clone() + } + + // TODO: mh should be a sketch, or even a sig... + pub(crate) fn find_signatures( + &self, + mh: &KmerMinHash, + threshold: f64, + containment: bool, + _ignore_scaled: bool, + ) -> Result, Error> { + /* + let template_mh = None; + if let Sketch::MinHash(mh) = self.template { + template_mh = Some(mh); + }; + // TODO: throw error + let template_mh = template_mh.unwrap(); + + let tmp_mh; + let mh = if template_mh.scaled() > mh.scaled() { + // TODO: proper error here + tmp_mh = mh.downsample_scaled(self.scaled)?; + &tmp_mh + } else { + mh + }; + + if self.scaled < mh.scaled() && !ignore_scaled { + return Err(LcaDBError::ScaledMismatchError { + db: self.scaled, + query: mh.scaled(), + } + .into()); + } + */ + + // TODO: proper threshold calculation + let threshold: usize = (threshold * (mh.size() as f64)) as _; + + let counter = self.counter_for_query(mh); + + debug!( + "number of matching signatures for hashes: {}", + counter.len() + ); + + let mut results = vec![]; + for (dataset_id, size) in counter.most_common() { + let match_size = if size >= threshold { size } else { break }; + + let p; + let match_path = if self.sig_files.is_empty() { + p = PathBuf::new(); // TODO: Fix somehow? + &p + } else { + &self.sig_files[dataset_id as usize] + }; + + let ref_match; + let match_sig = if let Some(refsigs) = &self.ref_sigs { + &refsigs[dataset_id as usize] + } else { + // TODO: remove swap_remove + ref_match = Signature::from_path(&match_path)?.swap_remove(0); + &ref_match + }; + + let mut match_mh = None; + if let Some(Sketch::MinHash(mh)) = match_sig.select_sketch(&self.template) { + match_mh = Some(mh); + } + let match_mh = match_mh.unwrap(); + + if size >= threshold { + let score = if containment { + size as f64 / mh.size() as f64 + } else { + size as f64 / (mh.size() + match_size - size) as f64 + }; + let filename = match_path.to_str().unwrap().into(); + let mut sig = match_sig.clone(); + sig.reset_sketches(); + sig.push(Sketch::MinHash(match_mh.clone())); + results.push((score, sig, filename)); + } else { + break; + }; + } + Ok(results) + } +} + +#[derive(CopyGetters, Getters, Setters, Serialize, Deserialize, Debug)] +pub struct GatherResult { + #[getset(get_copy = "pub")] + intersect_bp: usize, + + #[getset(get_copy = "pub")] + f_orig_query: f64, + + #[getset(get_copy = "pub")] + f_match: f64, + + f_unique_to_query: f64, + f_unique_weighted: f64, + average_abund: usize, + median_abund: usize, + std_abund: usize, + + #[getset(get = "pub")] + filename: String, + + #[getset(get = "pub")] + name: String, + + md5: String, + match_: Signature, + f_match_orig: f64, + unique_intersect_bp: usize, + gather_result_rank: usize, + remaining_bp: usize, +} + +impl GatherResult { + pub fn get_match(&self) -> Signature { + self.match_.clone() + } +} + +impl<'a> Index<'a> for RevIndex { + type Item = Signature; + + fn insert(&mut self, _node: Self::Item) -> Result<(), Error> { + unimplemented!() + } + + fn save>(&self, _path: P) -> Result<(), Error> { + unimplemented!() + } + + fn load>(_path: P) -> Result<(), Error> { + unimplemented!() + } + + fn len(&self) -> usize { + if let Some(refs) = &self.ref_sigs { + refs.len() + } else { + self.sig_files.len() + } + } + + fn signatures(&self) -> Vec { + if let Some(ref sigs) = self.ref_sigs { + sigs.to_vec() + } else { + unimplemented!() + } + } + + fn signature_refs(&self) -> Vec<&Self::Item> { + unimplemented!() + } +} + +#[cfg(test)] +mod test { + use super::*; + + use crate::sketch::minhash::max_hash_for_scaled; + + #[test] + fn revindex_new() { + let max_hash = max_hash_for_scaled(10000); + let template = Sketch::MinHash( + KmerMinHash::builder() + .num(0u32) + .ksize(31) + .max_hash(max_hash) + .build(), + ); + let search_sigs = [ + "../../tests/test-data/gather/GCF_000006945.2_ASM694v2_genomic.fna.gz.sig".into(), + "../../tests/test-data/gather/GCF_000007545.1_ASM754v1_genomic.fna.gz.sig".into(), + ]; + let index = RevIndex::new(&search_sigs, &template, 0, None, false); + assert_eq!(index.colors.len(), 3); + } + + #[test] + fn revindex_many() { + let max_hash = max_hash_for_scaled(10000); + let template = Sketch::MinHash( + KmerMinHash::builder() + .num(0u32) + .ksize(31) + .max_hash(max_hash) + .build(), + ); + let search_sigs = [ + "../../tests/test-data/gather/GCF_000006945.2_ASM694v2_genomic.fna.gz.sig".into(), + "../../tests/test-data/gather/GCF_000007545.1_ASM754v1_genomic.fna.gz.sig".into(), + "../../tests/test-data/gather/GCF_000008105.1_ASM810v1_genomic.fna.gz.sig".into(), + ]; + + let index = RevIndex::new(&search_sigs, &template, 0, None, false); + /* + dbg!(&index.colors.colors); + 0: 86 + 1: 132 + 2: 91 + (0, 1): 53 + (0, 2): 90 + (1, 2): 26 + (0, 1, 2): 261 + union: 739 + */ + //assert_eq!(index.colors.len(), 3); + assert_eq!(index.colors.len(), 7); + } +} diff --git a/src/core/src/index/sbt/mhbt.rs b/src/core/src/index/sbt/mhbt.rs index 67403cd334..2d4ceb3fb8 100644 --- a/src/core/src/index/sbt/mhbt.rs +++ b/src/core/src/index/sbt/mhbt.rs @@ -7,6 +7,7 @@ use crate::prelude::*; use crate::signature::SigsTrait; use crate::sketch::nodegraph::Nodegraph; use crate::sketch::Sketch; +use crate::storage::Storage; use crate::Error; impl ToWriter for Nodegraph { diff --git a/src/core/src/index/sbt/mod.rs b/src/core/src/index/sbt/mod.rs index b4baaca76b..4f4a7b82f9 100644 --- a/src/core/src/index/sbt/mod.rs +++ b/src/core/src/index/sbt/mod.rs @@ -16,7 +16,6 @@ use std::fs::File; use std::hash::BuildHasherDefault; use std::io::{BufReader, Read}; use std::path::{Path, PathBuf}; -use std::rc::Rc; use log::info; use nohash_hasher::NoHashHasher; @@ -26,7 +25,7 @@ use typed_builder::TypedBuilder; use crate::index::{Comparable, DatasetInfo, Index, SigStore}; use crate::prelude::*; -use crate::storage::{FSStorage, StorageInfo}; +use crate::storage::{FSStorage, InnerStorage, StorageInfo}; use crate::Error; #[derive(TypedBuilder)] @@ -35,7 +34,7 @@ pub struct SBT { d: u32, #[builder(default, setter(into))] - storage: Option>, + storage: Option, #[builder(default = Factory::GraphFactory { args: (1, 100000.0, 4) })] factory: Factory, @@ -79,7 +78,7 @@ where (0..u64::from(self.d)).map(|c| self.child(pos, c)).collect() } - pub fn storage(&self) -> Option> { + pub fn storage(&self) -> Option { self.storage.clone() } @@ -148,7 +147,7 @@ where SBTInfo::V6(ref sbt) => (&sbt.storage.args).into(), }; st.set_base(path.as_ref().to_str().unwrap()); - let storage: Rc = Rc::new(st); + let storage = InnerStorage::new(st); let d = match sinfo { SBTInfo::V4(ref sbt) => sbt.d, @@ -174,7 +173,7 @@ where .filename(l.filename) .name(l.name) .metadata(l.metadata) - .storage(Some(Rc::clone(&storage))) + .storage(Some(storage.clone())) .build(), ) }) @@ -189,7 +188,7 @@ where .filename(l.filename) .name(l.name) .metadata(l.metadata) - .storage(Some(Rc::clone(&storage))) + .storage(Some(storage.clone())) .build(), ) }) @@ -207,7 +206,7 @@ where .filename(l.filename) .name(l.name) .metadata(l.metadata) - .storage(Some(Rc::clone(&storage))) + .storage(Some(storage.clone())) .build(), ) }) @@ -222,7 +221,7 @@ where .filename(l.filename) .name(l.name) .metadata(l.metadata) - .storage(Some(Rc::clone(&storage))) + .storage(Some(storage.clone())) .build(), ) }) @@ -240,7 +239,7 @@ where .filename(l.filename.clone()) .name(l.name.clone()) .metadata(l.metadata.clone()) - .storage(Some(Rc::clone(&storage))) + .storage(Some(storage.clone())) .build(), )), NodeInfoV4::Leaf(_) => None, @@ -258,7 +257,7 @@ where .filename(l.filename) .name(l.name) .metadata(l.metadata) - .storage(Some(Rc::clone(&storage))) + .storage(Some(storage.clone())) .build(), )), }) @@ -271,7 +270,7 @@ where Ok(SBT { d, factory, - storage: Some(Rc::clone(&storage)), + storage: Some(storage), nodes, leaves, }) @@ -295,7 +294,7 @@ where pub fn save_file>( &mut self, path: P, - storage: Option>, + storage: Option, ) -> Result<(), Error> { let ref_path = path.as_ref(); let mut basename = ref_path.file_name().unwrap().to_str().unwrap().to_owned(); @@ -308,7 +307,7 @@ where Some(s) => s, None => { let subdir = format!(".sbt.{}", basename); - Rc::new(FSStorage::new(location.to_str().unwrap(), &subdir)) + InnerStorage::new(FSStorage::new(location.to_str().unwrap(), &subdir)) } }; @@ -331,7 +330,7 @@ where let _: &U = (*l).data().expect("Couldn't load data"); // set storage to new one - l.storage = Some(Rc::clone(&storage)); + l.storage = Some(storage.clone()); let filename = (*l).save(&l.filename).unwrap(); let new_node = NodeInfo { @@ -350,7 +349,7 @@ where let _: &T = (*l).data().unwrap(); // set storage to new one - l.storage = Some(Rc::clone(&storage)); + l.storage = Some(storage.clone()); // TODO: this should be l.md5sum(), not l.filename let filename = (*l).save(&l.filename).unwrap(); @@ -558,7 +557,7 @@ pub struct Node { metadata: HashMap, #[builder(default)] - storage: Option>, + storage: Option, #[builder(setter(into), default)] data: OnceCell, @@ -696,7 +695,7 @@ struct TreeNode { pub fn scaffold( mut datasets: Vec>, - storage: Option>, + storage: Option, ) -> SBT, Signature> where N: Clone + Default, diff --git a/src/core/src/signature.rs b/src/core/src/signature.rs index 008b522874..3121537130 100644 --- a/src/core/src/signature.rs +++ b/src/core/src/signature.rs @@ -608,9 +608,9 @@ impl Signature { if #[cfg(feature = "parallel")] { self.signatures .par_iter_mut() - .for_each(|sketch| { - sketch.add_sequence(seq, force).unwrap(); } - ); + .try_for_each(|sketch| { + sketch.add_sequence(seq, force) } + )?; } else { for sketch in self.signatures.iter_mut(){ sketch.add_sequence(seq, force)?; diff --git a/src/core/src/sketch/minhash.rs b/src/core/src/sketch/minhash.rs index 658f41d1ec..7557159839 100644 --- a/src/core/src/sketch/minhash.rs +++ b/src/core/src/sketch/minhash.rs @@ -774,6 +774,26 @@ impl KmerMinHash { hll } + + // create a downsampled copy of self + pub fn downsample_scaled(&self, scaled: u64) -> Result { + let max_hash = max_hash_for_scaled(scaled); + + let mut new_mh = KmerMinHash::new( + max_hash, // old max_hash => max_hash arg + self.ksize, + self.hash_function, + self.seed, + self.abunds.is_some(), + self.num, + ); + if self.abunds.is_some() { + new_mh.add_many_with_abund(&self.to_vec_abunds())?; + } else { + new_mh.add_many(&self.mins)?; + } + Ok(new_mh) + } } impl SigsTrait for KmerMinHash { diff --git a/src/core/src/storage.rs b/src/core/src/storage.rs index c1c70f7e08..598f47c7b6 100644 --- a/src/core/src/storage.rs +++ b/src/core/src/storage.rs @@ -1,6 +1,7 @@ use std::fs::{DirBuilder, File}; use std::io::{BufReader, BufWriter, Read, Write}; use std::path::PathBuf; +use std::sync::{Arc, Mutex}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -8,6 +9,39 @@ use typed_builder::TypedBuilder; use crate::Error; +/// An abstraction for any place where we can store data. +pub trait Storage { + /// Save bytes into path + fn save(&self, path: &str, content: &[u8]) -> Result; + + /// Load bytes from path + fn load(&self, path: &str) -> Result, Error>; + + /// Args for initializing a new Storage + fn args(&self) -> StorageArgs; +} + +#[derive(Clone)] +pub struct InnerStorage(Arc>); + +impl InnerStorage { + pub fn new(inner: impl Storage + 'static) -> InnerStorage { + InnerStorage(Arc::new(Mutex::new(inner))) + } +} + +impl Storage for InnerStorage { + fn save(&self, path: &str, content: &[u8]) -> Result { + self.0.save(path, content) + } + fn load(&self, path: &str) -> Result, Error> { + self.0.load(path) + } + fn args(&self) -> StorageArgs { + self.0.args() + } +} + #[derive(Debug, Error)] pub enum StorageError { #[error("Path can't be empty")] @@ -43,16 +77,21 @@ impl From<&StorageArgs> for FSStorage { } } -/// An abstraction for any place where we can store data. -pub trait Storage { - /// Save bytes into path - fn save(&self, path: &str, content: &[u8]) -> Result; +impl Storage for Mutex +where + L: ?Sized + Storage, +{ + fn save(&self, path: &str, content: &[u8]) -> Result { + self.lock().unwrap().save(path, content) + } - /// Load bytes from path - fn load(&self, path: &str) -> Result, Error>; + fn load(&self, path: &str) -> Result, Error> { + self.lock().unwrap().load(path) + } - /// Args for initializing a new Storage - fn args(&self) -> StorageArgs; + fn args(&self) -> StorageArgs { + self.lock().unwrap().args() + } } /// Store files locally into a directory @@ -115,3 +154,90 @@ impl Storage for FSStorage { } } } + +pub struct ZipStorage<'a> { + mapping: Option, + archive: Option>, + //metadata: piz::read::DirectoryContents<'a>, +} + +fn load_from_archive<'a>(archive: &'a piz::ZipArchive<'a>, path: &str) -> Result, Error> { + use piz::read::FileTree; + + // FIXME error + let tree = piz::read::as_tree(archive.entries()).map_err(|_| StorageError::EmptyPathError)?; + // FIXME error + let entry = tree + .lookup(path) + .map_err(|_| StorageError::EmptyPathError)?; + + // FIXME error + let mut reader = BufReader::new( + archive + .read(entry) + .map_err(|_| StorageError::EmptyPathError)?, + ); + let mut contents = Vec::new(); + reader.read_to_end(&mut contents)?; + + Ok(contents) +} + +impl<'a> Storage for ZipStorage<'a> { + fn save(&self, _path: &str, _content: &[u8]) -> Result { + unimplemented!(); + } + + fn load(&self, path: &str) -> Result, Error> { + if let Some(archive) = &self.archive { + load_from_archive(archive, path) + } else { + //FIXME + let archive = piz::ZipArchive::new((&self.mapping.as_ref()).unwrap()) + .map_err(|_| StorageError::EmptyPathError)?; + load_from_archive(&archive, path) + } + } + + fn args(&self) -> StorageArgs { + unimplemented!(); + } +} + +impl<'a> ZipStorage<'a> { + pub fn new(location: &str) -> Result { + let zip_file = File::open(location)?; + let mapping = unsafe { memmap2::Mmap::map(&zip_file)? }; + + //FIXME + //let archive = piz::ZipArchive::new(&mapping).map_err(|_| StorageError::EmptyPathError)?; + + //FIXME + // let tree = + // piz::read::as_tree(archive.entries()).map_err(|_| StorageError::EmptyPathError)?; + + Ok(Self { + mapping: Some(mapping), + archive: None, + //metadata: tree, + }) + } + + pub fn from_slice(mapping: &'a [u8]) -> Result { + //FIXME + let archive = piz::ZipArchive::new(mapping).map_err(|_| StorageError::EmptyPathError)?; + + //FIXME + //let entries: Vec<_> = archive.entries().iter().map(|x| x.to_owned()).collect(); + //let tree = + // piz::read::as_tree(entries.as_slice()).map_err(|_| StorageError::EmptyPathError)?; + + Ok(Self { + archive: Some(archive), + mapping: None, + /* metadata: archive + .as_tree() + .map_err(|_| StorageError::EmptyPathError)?, */ + }) + } +} diff --git a/src/core/tests/storage.rs b/src/core/tests/storage.rs new file mode 100644 index 0000000000..202ddd3fe3 --- /dev/null +++ b/src/core/tests/storage.rs @@ -0,0 +1,37 @@ +use std::fs::File; +use std::path::PathBuf; + +use sourmash::storage::{Storage, ZipStorage}; + +#[test] +fn zipstorage_load_file() -> Result<(), Box> { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/v6.sbt.zip"); + + let zs = ZipStorage::new(filename.to_str().unwrap())?; + + let data = zs.load("v6.sbt.json")?; + + let description: serde_json::Value = serde_json::from_slice(&data[..])?; + assert_eq!(description["version"], 6); + + Ok(()) +} + +#[test] +fn zipstorage_load_slice() -> Result<(), Box> { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/v6.sbt.zip"); + + let zip_file = File::open(filename)?; + let mapping = unsafe { memmap2::Mmap::map(&zip_file)? }; + + let zs = ZipStorage::from_slice(&mapping)?; + + let data = zs.load("v6.sbt.json")?; + + let description: serde_json::Value = serde_json::from_slice(&data[..])?; + assert_eq!(description["version"], 6); + + Ok(()) +} diff --git a/src/sourmash/index.py b/src/sourmash/index/__init__.py similarity index 99% rename from src/sourmash/index.py rename to src/sourmash/index/__init__.py index a2e362d127..e4b68ca58f 100644 --- a/src/sourmash/index.py +++ b/src/sourmash/index/__init__.py @@ -40,10 +40,10 @@ class MultiIndex - in-memory storage and selection of signatures from multiple import csv from io import TextIOWrapper -from .search import make_jaccard_search_query, make_gather_query -from .manifest import CollectionManifest -from .logging import debug_literal -from .signature import load_signatures, save_signatures +from ..search import make_jaccard_search_query, make_gather_query +from ..manifest import CollectionManifest +from ..logging import debug_literal +from ..signature import load_signatures, save_signatures # generic return tuple for Index.search and Index.gather IndexSearchResult = namedtuple('Result', 'score, signature, location') @@ -575,7 +575,7 @@ def save(self, path): @classmethod def load(cls, location, traverse_yield_all=False, use_manifest=True): "Class method to load a zipfile." - from .sbt_storage import ZipStorage + from ..sbt_storage import ZipStorage # we can only load from existing zipfiles in this method. if not os.path.exists(location): @@ -916,7 +916,7 @@ def load_from_directory(cls, pathname, *, force=False): load all files ending in .sig or .sig.gz, by default; if 'force' is True, will attempt to load _all_ files, ignoring errors. """ - from .sourmash_args import traverse_find_sigs + from ..sourmash_args import traverse_find_sigs if not os.path.isdir(pathname): raise ValueError(f"'{pathname}' must be a directory.") @@ -979,7 +979,7 @@ def load_from_pathlist(cls, filename): if they are listed in the text file; it uses 'load_file_as_index' underneath. """ - from .sourmash_args import (load_pathlist_from_file, + from ..sourmash_args import (load_pathlist_from_file, load_file_as_index) idx_list = [] src_list = [] diff --git a/src/sourmash/index/revindex.py b/src/sourmash/index/revindex.py new file mode 100644 index 0000000000..f4346f074c --- /dev/null +++ b/src/sourmash/index/revindex.py @@ -0,0 +1,253 @@ +import weakref + +from sourmash.index import Index, IndexSearchResult +from sourmash.minhash import MinHash +from sourmash.signature import SourmashSignature +from sourmash._lowlevel import ffi, lib +from sourmash.utils import RustObject, rustcall, decode_str, encode_str + + +class RevIndex(RustObject, Index): + __dealloc_func__ = lib.revindex_free + + def __init__( + self, + *, + signatures=None, + signature_paths=None, + template=None, + threshold=0, + queries=None, + keep_sigs=False, + ): + self.template = template + self.threshold = threshold + self.queries = queries + self.keep_sigs = keep_sigs + self.signature_paths = signature_paths + self._signatures = signatures + + if signature_paths is None or signatures is None: + # delay initialization + self._objptr = ffi.NULL + else: + self._init_inner() + + def _init_inner(self): + if self._objptr != ffi.NULL: + # Already initialized + return + + if ( + self.signature_paths is None + and not self._signatures + and self._objptr == ffi.NULL + ): + raise ValueError("No signatures provided") + elif (self.signature_paths or self._signatures) and self._objptr != ffi.NULL: + raise NotImplementedError("Need to update RevIndex") + + attached_refs = weakref.WeakKeyDictionary() + + queries_ptr = ffi.NULL + queries_size = 0 + if self.queries: + # get list of rust objects + collected = [] + for obj in queries: + rv = obj._get_objptr() + attached_refs[rv] = obj + collected.append(rv) + queries_ptr = ffi.new("SourmashSignature*[]", collected) + queries_size = len(queries) + + template_ptr = ffi.NULL + if self.template: + if isinstance(self.template, MinHash): + template_ptr = self.template._get_objptr() + else: + raise ValueError("Template must be a MinHash") + + search_sigs_ptr = ffi.NULL + sigs_size = 0 + collected = [] + if self.signature_paths: + for path in self.signature_paths: + collected.append(encode_str(path)) + search_sigs_ptr = ffi.new("SourmashStr*[]", collected) + sigs_size = len(signature_paths) + + self._objptr = rustcall( + lib.revindex_new_with_paths, + search_sigs_ptr, + sigs_size, + template_ptr, + self.threshold, + queries_ptr, + queries_size, + self.keep_sigs, + ) + elif self._signatures: + # force keep_sigs=True, and pass SourmashSignature directly to RevIndex. + for sig in self._signatures: + collected.append(sig._get_objptr()) + search_sigs_ptr = ffi.new("SourmashSignature*[]", collected) + sigs_size = len(self._signatures) + + self._objptr = rustcall( + lib.revindex_new_with_sigs, + search_sigs_ptr, + sigs_size, + template_ptr, + self.threshold, + queries_ptr, + queries_size, + ) + + def signatures(self): + self._init_inner() + + size = ffi.new("uintptr_t *") + sigs_ptr = self._methodcall(lib.revindex_signatures, size) + size = size[0] + + sigs = [] + for i in range(size): + sig = SourmashSignature._from_objptr(sigs_ptr[i]) + sigs.append(sig) + + for sig in sigs: + yield sig + + #if self._signatures: + # yield from self._signatures + #else: + # raise NotImplementedError("Call into Rust and retrieve sigs") + + def __len__(self): + if self._objptr: + return self._methodcall(lib.revindex_len) + else: + return len(self._signatures) + + def insert(self, node): + if self._signatures is None: + self._signatures = [] + self._signatures.append(node) + + def save(self, path): + pass + + @classmethod + def load(cls, location): + pass + + def select(self, ksize=None, moltype=None): + if self.template: + if ksize: + self.template.ksize = ksize + if moltype: + self.template.moltype = moltype + else: + # TODO: deal with None/default values + self.template = MinHash(ksize=ksize, moltype=moltype) + +# def search(self, query, *args, **kwargs): +# """Return set of matches with similarity above 'threshold'. +# +# Results will be sorted by similarity, highest to lowest. +# +# Optional arguments: +# * do_containment: default False. If True, use Jaccard containment. +# * ignore_abundance: default False. If True, and query signature +# and database support k-mer abundances, ignore those abundances. +# +# Note, the "best only" hint is ignored by LCA_Database +# """ +# if not query.minhash: +# return [] +# +# # check arguments +# if "threshold" not in kwargs: +# raise TypeError("'search' requires 'threshold'") +# threshold = kwargs["threshold"] +# do_containment = kwargs.get("do_containment", False) +# ignore_abundance = kwargs.get("ignore_abundance", False) +# +# self._init_inner() +# +# size = ffi.new("uintptr_t *") +# results_ptr = self._methodcall( +# lib.revindex_search, +# query._get_objptr(), +# threshold, +# do_containment, +# ignore_abundance, +# size, +# ) +# +# size = size[0] +# if size == 0: +# return [] +# +# results = [] +# for i in range(size): +# match = SearchResult._from_objptr(results_ptr[i]) +# if match.score >= threshold: +# results.append(IndexSearchResult(match.score, match.signature, match.filename)) +# +# return results +# +# def gather(self, query, *args, **kwargs): +# "Return the match with the best Jaccard containment in the database." +# if not query.minhash: +# return [] +# +# self._init_inner() +# +# threshold_bp = kwargs.get("threshold_bp", 0.0) +# threshold = threshold_bp / (len(query.minhash) * self.scaled) +# +# results = [] +# size = ffi.new("uintptr_t *") +# results_ptr = self._methodcall( +# lib.revindex_gather, query._get_objptr(), threshold, True, True, size +# ) +# size = size[0] +# if size == 0: +# return [] +# +# results = [] +# for i in range(size): +# match = SearchResult._from_objptr(results_ptr[i]) +# if match.score >= threshold: +# results.append(IndexSearchResult(match.score, match.signature, match.filename)) +# +# results.sort(reverse=True, +# key=lambda x: (x.score, x.signature.md5sum())) +# +# return results[:1] + + @property + def scaled(self): + return self._methodcall(lib.revindex_scaled) + + +class SearchResult(RustObject): + __dealloc_func__ = lib.searchresult_free + + @property + def score(self): + return self._methodcall(lib.searchresult_score) + + @property + def signature(self): + sig_ptr = self._methodcall(lib.searchresult_signature) + return SourmashSignature._from_objptr(sig_ptr) + + @property + def filename(self): + result = decode_str(self._methodcall(lib.searchresult_filename)) + if result == "": + return None + return result diff --git a/src/sourmash/utils.py b/src/sourmash/utils.py index 5790a4fac6..555833d9c1 100644 --- a/src/sourmash/utils.py +++ b/src/sourmash/utils.py @@ -52,7 +52,7 @@ def decode_str(s): def encode_str(s): """Encodes a SourmashStr""" rv = ffi.new("SourmashStr *") - if isinstance(s, text_type): + if isinstance(s, str): s = s.encode("utf-8") rv.data = ffi.from_buffer(s) rv.len = len(s) diff --git a/tests/test_index.py b/tests/test_index.py index 80c08ce1e9..31d1a0ac18 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -14,6 +14,7 @@ from sourmash.index import (LinearIndex, ZipFileLinearIndex, make_jaccard_search_query, CounterGather, LazyLinearIndex, MultiIndex) +from sourmash.index.revindex import RevIndex from sourmash.sbt import SBT, GraphFactory, Leaf from sourmash.sbtmh import SigLeaf from sourmash import sourmash_args @@ -2260,7 +2261,6 @@ def test_lazy_index_wraps_multi_index_location(): lazy2.signatures_with_location()): assert ss_tup == ss_lazy_tup - def test_lazy_loaded_index_1(runtmp): # some basic tests for LazyLoadedIndex lcafile = utils.get_test_data('prot/protein.lca.json.gz') @@ -2338,3 +2338,101 @@ def test_lazy_loaded_index_3_find(runtmp): x = db.search(query, threshold=0.0) x = list(x) assert len(x) == 0 + +def test_revindex_index_search(): + sig2 = utils.get_test_data("2.fa.sig") + sig47 = utils.get_test_data("47.fa.sig") + sig63 = utils.get_test_data("63.fa.sig") + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47) + ss63 = sourmash.load_one_signature(sig63) + + lidx = RevIndex(template=ss2.minhash) + lidx.insert(ss2) + lidx.insert(ss47) + lidx.insert(ss63) + + # now, search for sig2 + sr = lidx.search(ss2, threshold=1.0) + print([s[1].name for s in sr]) + assert len(sr) == 1 + assert sr[0][1] == ss2 + + # search for sig47 with lower threshold; search order not guaranteed. + sr = lidx.search(ss47, threshold=0.1) + print([s[1].name for s in sr]) + assert len(sr) == 2 + sr.sort(key=lambda x: -x[0]) + assert sr[0][1] == ss47 + assert sr[1][1] == ss63 + + # search for sig63 with lower threshold; search order not guaranteed. + sr = lidx.search(ss63, threshold=0.1) + print([s[1].name for s in sr]) + assert len(sr) == 2 + sr.sort(key=lambda x: -x[0]) + assert sr[0][1] == ss63 + assert sr[1][1] == ss47 + + # search for sig63 with high threshold => 1 match + sr = lidx.search(ss63, threshold=0.8) + print([s[1].name for s in sr]) + assert len(sr) == 1 + sr.sort(key=lambda x: -x[0]) + assert sr[0][1] == ss63 + + +def test_revindex_gather(): + sig2 = utils.get_test_data("2.fa.sig") + sig47 = utils.get_test_data("47.fa.sig") + sig63 = utils.get_test_data("63.fa.sig") + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47) + ss63 = sourmash.load_one_signature(sig63) + + lidx = RevIndex(template=ss2.minhash) + lidx.insert(ss2) + lidx.insert(ss47) + lidx.insert(ss63) + + matches = lidx.gather(ss2) + assert len(matches) == 1 + assert matches[0][0] == 1.0 + assert matches[0][1] == ss2 + + matches = lidx.gather(ss47) + assert len(matches) == 1 + assert matches[0][0] == 1.0 + assert matches[0][1] == ss47 + + +def test_revindex_gather_ignore(): + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47, ksize=31) + ss63 = sourmash.load_one_signature(sig63, ksize=31) + + # construct an index... + lidx = RevIndex(template=ss2.minhash, signatures=[ss2, ss47, ss63]) + + # ...now search with something that should ignore sig47, the exact match. + search_fn = JaccardSearchBestOnly_ButIgnore([ss47]) + + results = list(lidx.find(search_fn, ss47)) + results = [ ss.signature for ss in results ] + + def is_found(ss, xx): + for q in xx: + print(ss, ss.similarity(q)) + if ss.similarity(q) == 1.0: + return True + return False + + assert not is_found(ss47, results) + assert not is_found(ss2, results) + assert is_found(ss63, results)