diff --git a/Cargo.lock b/Cargo.lock index 930a9cdf53..d7a92d6221 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1419,6 +1419,7 @@ dependencies = [ "cfg-if", "counter", "criterion", + "csv", "finch", "fixedbitset", "flume", diff --git a/Makefile b/Makefile index 4c2ef69abb..990f79c068 100644 --- a/Makefile +++ b/Makefile @@ -28,6 +28,8 @@ include/sourmash.h: src/core/src/lib.rs \ src/core/src/ffi/hyperloglog.rs \ src/core/src/ffi/minhash.rs \ src/core/src/ffi/signature.rs \ + src/core/src/ffi/manifest.rs \ + src/core/src/ffi/picklist.rs \ src/core/src/ffi/nodegraph.rs \ src/core/src/ffi/index/mod.rs \ src/core/src/ffi/index/revindex.rs \ diff --git a/include/sourmash.h b/include/sourmash.h index 1035ed0b46..a36cb326f9 100644 --- a/include/sourmash.h +++ b/include/sourmash.h @@ -42,6 +42,7 @@ enum SourmashErrorCode { SOURMASH_ERROR_CODE_PARSE_INT = 100003, SOURMASH_ERROR_CODE_SERDE_ERROR = 100004, SOURMASH_ERROR_CODE_NIFFLER_ERROR = 100005, + SOURMASH_ERROR_CODE_CSV_ERROR = 100006, }; typedef uint32_t SourmashErrorCode; @@ -51,14 +52,26 @@ typedef struct SourmashHyperLogLog SourmashHyperLogLog; typedef struct SourmashKmerMinHash SourmashKmerMinHash; +typedef struct SourmashLinearIndex SourmashLinearIndex; + +typedef struct SourmashManifest SourmashManifest; + +typedef struct SourmashManifestRowIter SourmashManifestRowIter; + typedef struct SourmashNodegraph SourmashNodegraph; +typedef struct SourmashPicklist SourmashPicklist; + typedef struct SourmashRevIndex SourmashRevIndex; typedef struct SourmashSearchResult SourmashSearchResult; +typedef struct SourmashSelection SourmashSelection; + typedef struct SourmashSignature SourmashSignature; +typedef struct SourmashSignatureIter SourmashSignatureIter; + typedef struct SourmashZipStorage SourmashZipStorage; /** @@ -79,6 +92,15 @@ typedef struct { bool owned; } SourmashStr; +typedef struct { + uint32_t ksize; + uint8_t with_abundance; + SourmashStr md5; + SourmashStr internal_location; + SourmashStr name; + SourmashStr moltype; +} SourmashManifestRow; + bool computeparams_dayhoff(const SourmashComputeParameters *ptr); bool computeparams_dna(const SourmashComputeParameters *ptr); @@ -269,6 +291,32 @@ SourmashKmerMinHash *kmerminhash_to_mutable(const SourmashKmerMinHash *ptr); bool kmerminhash_track_abundance(const SourmashKmerMinHash *ptr); +void linearindex_free(SourmashLinearIndex *ptr); + +uint64_t linearindex_len(const SourmashLinearIndex *ptr); + +SourmashStr linearindex_location(const SourmashLinearIndex *ptr); + +const SourmashManifest *linearindex_manifest(const SourmashLinearIndex *ptr); + +SourmashLinearIndex *linearindex_new(SourmashZipStorage *storage_ptr, + SourmashManifest *manifest_ptr, + SourmashSelection *selection_ptr, + bool use_manifest); + +SourmashLinearIndex *linearindex_select(SourmashLinearIndex *ptr, + const SourmashSelection *selection_ptr); + +void linearindex_set_manifest(SourmashLinearIndex *ptr, SourmashManifest *manifest_ptr); + +SourmashSignatureIter *linearindex_signatures(const SourmashLinearIndex *ptr); + +const SourmashZipStorage *linearindex_storage(const SourmashLinearIndex *ptr); + +SourmashManifestRowIter *manifest_rows(const SourmashManifest *ptr); + +const SourmashManifestRow *manifest_rows_iter_next(SourmashManifestRowIter *ptr); + void nodegraph_buffer_free(uint8_t *ptr, uintptr_t insize); bool nodegraph_count(SourmashNodegraph *ptr, uint64_t h); @@ -313,6 +361,16 @@ SourmashNodegraph *nodegraph_with_tables(uintptr_t ksize, uintptr_t starting_size, uintptr_t n_tables); +void 
picklist_free(SourmashPicklist *ptr); + +SourmashPicklist *picklist_new(void); + +void picklist_set_coltype(SourmashPicklist *ptr, const char *coltype_ptr, uintptr_t insize); + +void picklist_set_column_name(SourmashPicklist *ptr, const char *prop_ptr, uintptr_t insize); + +void picklist_set_pickfile(SourmashPicklist *ptr, const char *prop_ptr, uintptr_t insize); + void revindex_free(SourmashRevIndex *ptr); const SourmashSearchResult *const *revindex_gather(const SourmashRevIndex *ptr, @@ -358,6 +416,20 @@ double searchresult_score(const SourmashSearchResult *ptr); SourmashSignature *searchresult_signature(const SourmashSearchResult *ptr); +bool selection_abund(const SourmashSelection *ptr); + +uint32_t selection_ksize(const SourmashSelection *ptr); + +HashFunctions selection_moltype(const SourmashSelection *ptr); + +SourmashSelection *selection_new(void); + +void selection_set_abund(SourmashSelection *ptr, bool new_abund); + +void selection_set_ksize(SourmashSelection *ptr, uint32_t new_ksize); + +void selection_set_moltype(SourmashSelection *ptr, HashFunctions new_moltype); + void signature_add_protein(SourmashSignature *ptr, const char *sequence); void signature_add_sequence(SourmashSignature *ptr, const char *sequence, bool force); @@ -388,6 +460,8 @@ void signature_set_mh(SourmashSignature *ptr, const SourmashKmerMinHash *other); void signature_set_name(SourmashSignature *ptr, const char *name); +const SourmashSignature *signatures_iter_next(SourmashSignatureIter *ptr); + SourmashSignature **signatures_load_buffer(const char *ptr, uintptr_t insize, bool _ignore_md5sum, diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 63f0b08146..5f0b04b7f9 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -31,6 +31,7 @@ bytecount = "0.6.0" byteorder = "1.4.3" cfg-if = "1.0" counter = "0.5.7" +csv = "1.1.6" finch = { version = "0.5.0", optional = true } fixedbitset = "0.4.0" getset = "0.1.1" diff --git a/src/core/src/errors.rs b/src/core/src/errors.rs index fc1eb285d0..2becd844bd 100644 --- a/src/core/src/errors.rs +++ b/src/core/src/errors.rs @@ -63,6 +63,9 @@ pub enum SourmashError { #[error(transparent)] IOError(#[from] std::io::Error), + #[error(transparent)] + CsvError(#[from] csv::Error), + #[cfg(not(all(target_arch = "wasm32", target_vendor = "unknown")))] #[error(transparent)] Panic(#[from] crate::ffi::utils::Panic), @@ -108,6 +111,7 @@ pub enum SourmashErrorCode { ParseInt = 100_003, SerdeError = 100_004, NifflerError = 100_005, + CsvError = 100_006, } #[cfg(not(all(target_arch = "wasm32", target_vendor = "unknown")))] @@ -137,6 +141,7 @@ impl SourmashErrorCode { SourmashError::IOError { .. } => SourmashErrorCode::Io, SourmashError::NifflerError { .. } => SourmashErrorCode::NifflerError, SourmashError::Utf8Error { .. } => SourmashErrorCode::Utf8Error, + SourmashError::CsvError { .. 
} => SourmashErrorCode::CsvError, } } } diff --git a/src/core/src/ffi/index/mod.rs b/src/core/src/ffi/index/mod.rs index 932a97b222..77eca58936 100644 --- a/src/core/src/ffi/index/mod.rs +++ b/src/core/src/ffi/index/mod.rs @@ -1,5 +1,8 @@ pub mod revindex; +use crate::encodings::HashFunctions; +use crate::index::{Selection, SigStore}; + use crate::signature::Signature; use crate::ffi::signature::SourmashSignature; @@ -35,3 +38,91 @@ pub unsafe extern "C" fn searchresult_signature( let result = SourmashSearchResult::as_rust(ptr); SourmashSignature::from_rust((result.1).clone()) } + +//================================================================ + +pub struct SourmashSelection; + +impl ForeignObject for SourmashSelection { + type RustObject = Selection; +} + +#[no_mangle] +pub unsafe extern "C" fn selection_new() -> *mut SourmashSelection { + SourmashSelection::from_rust(Selection::default()) +} + +#[no_mangle] +pub unsafe extern "C" fn selection_ksize(ptr: *const SourmashSelection) -> u32 { + let sel = SourmashSelection::as_rust(ptr); + if let Some(ksize) = sel.ksize() { + ksize + } else { + todo!("empty ksize case not supported yet") + } +} + +#[no_mangle] +pub unsafe extern "C" fn selection_set_ksize(ptr: *mut SourmashSelection, new_ksize: u32) { + let sel = SourmashSelection::as_rust_mut(ptr); + sel.set_ksize(new_ksize); +} + +#[no_mangle] +pub unsafe extern "C" fn selection_abund(ptr: *const SourmashSelection) -> bool { + let sel = SourmashSelection::as_rust(ptr); + if let Some(abund) = sel.abund() { + abund + } else { + todo!("empty abund case not supported yet") + } +} + +#[no_mangle] +pub unsafe extern "C" fn selection_set_abund(ptr: *mut SourmashSelection, new_abund: bool) { + let sel = SourmashSelection::as_rust_mut(ptr); + sel.set_abund(new_abund); +} + +#[no_mangle] +pub unsafe extern "C" fn selection_moltype(ptr: *const SourmashSelection) -> HashFunctions { + let sel = SourmashSelection::as_rust(ptr); + if let Some(hash_function) = sel.moltype() { + hash_function + } else { + todo!("empty hash_function case not supported yet") + } +} + +#[no_mangle] +pub unsafe extern "C" fn selection_set_moltype( + ptr: *mut SourmashSelection, + new_moltype: HashFunctions, +) { + let sel = SourmashSelection::as_rust_mut(ptr); + sel.set_moltype(new_moltype); +} + +//================================================================ +// +pub struct SignatureIterator { + iter: Box<dyn Iterator<Item = SigStore>>, +} + +pub struct SourmashSignatureIter; + +impl ForeignObject for SourmashSignatureIter { + type RustObject = SignatureIterator; +} + +#[no_mangle] +pub unsafe extern "C" fn signatures_iter_next( + ptr: *mut SourmashSignatureIter, +) -> *const SourmashSignature { + let iterator = SourmashSignatureIter::as_rust_mut(ptr); + + match iterator.iter.next() { + Some(sig) => SourmashSignature::from_rust(sig.into()), + None => std::ptr::null(), + } +} diff --git a/src/core/src/ffi/index/revindex.rs b/src/core/src/ffi/index/revindex.rs index cecfafbc51..abf0bc6bad 100644 --- a/src/core/src/ffi/index/revindex.rs +++ b/src/core/src/ffi/index/revindex.rs @@ -1,15 +1,22 @@ use std::path::PathBuf; use std::slice; +use std::sync::Arc; -use crate::index::revindex::mem_revindex::RevIndex; +use crate::index::revindex::mem_revindex::{LinearRevIndex, RevIndex}; use crate::index::Index; +use crate::manifest::Manifest; use crate::signature::{Signature, SigsTrait}; -use crate::sketch::minhash::KmerMinHash; +use crate::sketch::minhash::{max_hash_for_scaled, KmerMinHash}; use crate::sketch::Sketch; +use crate::storage::Storage; -use
crate::ffi::index::SourmashSearchResult; +use crate::ffi::index::{ + SignatureIterator, SourmashSearchResult, SourmashSelection, SourmashSignatureIter, +}; +use crate::ffi::manifest::SourmashManifest; use crate::ffi::minhash::{MinHash, SourmashKmerMinHash}; use crate::ffi::signature::SourmashSignature; +use crate::ffi::storage::SourmashZipStorage; use crate::ffi::utils::{ForeignObject, SourmashStr}; pub struct SourmashRevIndex; @@ -253,3 +260,141 @@ unsafe fn revindex_signatures( Ok(Box::into_raw(b) as *mut *mut SourmashSignature) } } + +//-------------------------------------------------- + +pub struct SourmashLinearIndex; + +impl ForeignObject for SourmashLinearIndex { + type RustObject = LinearRevIndex; +} + +ffi_fn! { +unsafe fn linearindex_new( + storage_ptr: *mut SourmashZipStorage, + manifest_ptr: *mut SourmashManifest, + selection_ptr: *mut SourmashSelection, + use_manifest: bool, +) -> Result<*mut SourmashLinearIndex> { + let storage = Arc::try_unwrap(*SourmashZipStorage::into_rust(storage_ptr)).ok().unwrap(); + + let manifest = if manifest_ptr.is_null() { + if use_manifest { + // Load manifest from zipstorage + Some(Manifest::from_reader(storage.load("SOURMASH-MANIFEST.csv")?.as_slice())?) + } else { + None + } + } else { + Some(*SourmashManifest::into_rust(manifest_ptr)) + }; + + let _selection = if !selection_ptr.is_null() { + Some(SourmashSelection::into_rust(selection_ptr)) + } else { + None + }; + // TODO: how to extract a template? Probably from selection? + let max_hash = max_hash_for_scaled(100); + let template = Sketch::MinHash( + KmerMinHash::builder() + .num(0u32) + .ksize(57) + .hash_function(crate::encodings::HashFunctions::murmur64_protein) + .max_hash(max_hash) + .build(), + ); + + /* + def __init__(self, storage, *, selection_dict=None, + traverse_yield_all=False, manifest=None, use_manifest=True): + sig_files: Manifest, + template: &Sketch, + keep_sigs: bool, + ref_sigs: Option>, + storage: Option, + */ + + let linear_index = LinearRevIndex::new(manifest, &template, false, None, Some(storage)); + + Ok(SourmashLinearIndex::from_rust(linear_index)) +} +} + +#[no_mangle] +pub unsafe extern "C" fn linearindex_free(ptr: *mut SourmashLinearIndex) { + SourmashLinearIndex::drop(ptr); +} + +#[no_mangle] +pub unsafe extern "C" fn linearindex_manifest( + ptr: *const SourmashLinearIndex, +) -> *const SourmashManifest { + let index = SourmashLinearIndex::as_rust(ptr); + SourmashManifest::from_rust(index.manifest()) +} + +ffi_fn! 
{ +unsafe fn linearindex_set_manifest( + ptr: *mut SourmashLinearIndex, + manifest_ptr: *mut SourmashManifest, +) -> Result<()> { + let index = SourmashLinearIndex::as_rust_mut(ptr); + let manifest = SourmashManifest::into_rust(manifest_ptr); + + index.set_manifest(*manifest)?; + Ok(()) +} +} + +#[no_mangle] +pub unsafe extern "C" fn linearindex_len(ptr: *const SourmashLinearIndex) -> u64 { + let index = SourmashLinearIndex::as_rust(ptr); + index.len() as u64 +} + +#[no_mangle] +pub unsafe extern "C" fn linearindex_location(ptr: *const SourmashLinearIndex) -> SourmashStr { + let index = SourmashLinearIndex::as_rust(ptr); + match index.location() { + Some(x) => x, + None => "".into(), + } + .into() +} + +#[no_mangle] +pub unsafe extern "C" fn linearindex_storage( + ptr: *const SourmashLinearIndex, +) -> *const SourmashZipStorage { + let index = SourmashLinearIndex::as_rust(ptr); + let storage = index.storage(); + + match storage { + Some(st) => SourmashZipStorage::from_rust(st), + None => std::ptr::null::<SourmashZipStorage>(), + } +} + +#[no_mangle] +pub unsafe extern "C" fn linearindex_signatures( + ptr: *const SourmashLinearIndex, +) -> *mut SourmashSignatureIter { + let index = SourmashLinearIndex::as_rust(ptr); + + let iter = Box::new(index.signatures_iter()); + SourmashSignatureIter::from_rust(SignatureIterator { iter }) +} + +ffi_fn! { +unsafe fn linearindex_select( + ptr: *mut SourmashLinearIndex, + selection_ptr: *const SourmashSelection, +) -> Result<*mut SourmashLinearIndex> { + let index = SourmashLinearIndex::into_rust(ptr); + let selection = SourmashSelection::as_rust(selection_ptr); + + let new_index = index.select(selection)?; + Ok(SourmashLinearIndex::from_rust(new_index)) +} +} diff --git a/src/core/src/ffi/manifest.rs b/src/core/src/ffi/manifest.rs new file mode 100644 index 0000000000..815f8d83f1 --- /dev/null +++ b/src/core/src/ffi/manifest.rs @@ -0,0 +1,73 @@ +use crate::manifest::{Manifest, Record}; + +use crate::ffi::utils::{ForeignObject, SourmashStr}; + +pub struct SourmashManifest; + +impl ForeignObject for SourmashManifest { + type RustObject = Manifest; +} + +pub struct ManifestRowIterator { + iter: Box<dyn Iterator<Item = &'static Record>>, +} + +pub struct SourmashManifestRowIter; + +impl ForeignObject for SourmashManifestRowIter { + type RustObject = ManifestRowIterator; +} + +#[no_mangle] +pub unsafe extern "C" fn manifest_rows_iter_next( + ptr: *mut SourmashManifestRowIter, +) -> *const SourmashManifestRow { + let iterator = SourmashManifestRowIter::as_rust_mut(ptr); + + match iterator.iter.next() { + Some(row) => SourmashManifestRow::from_rust(row.into()), + None => std::ptr::null(), + } +} + +#[no_mangle] +pub unsafe extern "C" fn manifest_rows( + ptr: *const SourmashManifest, +) -> *mut SourmashManifestRowIter { + let manifest = SourmashManifest::as_rust(ptr); + + let iter = Box::new(manifest.iter()); + SourmashManifestRowIter::from_rust(ManifestRowIterator { iter }) +} + +#[repr(C)] +pub struct SourmashManifestRow { + pub ksize: u32, + pub with_abundance: u8, + pub md5: SourmashStr, + pub internal_location: SourmashStr, + pub name: SourmashStr, + pub moltype: SourmashStr, +} + +impl ForeignObject for SourmashManifestRow { + type RustObject = SourmashManifestRow; +} + +impl From<&Record> for SourmashManifestRow { + fn from(record: &Record) -> SourmashManifestRow { + Self { + ksize: record.ksize(), + with_abundance: record.with_abundance() as u8, + md5: record.md5().into(), + name: record.name().into(), + moltype: record.moltype().to_string().into(), + internal_location: record + .internal_location()
.to_str() + .unwrap() + .to_owned() + .into(), + } + } +} diff --git a/src/core/src/ffi/mod.rs b/src/core/src/ffi/mod.rs index a67de37176..44e856001f 100644 --- a/src/core/src/ffi/mod.rs +++ b/src/core/src/ffi/mod.rs @@ -9,8 +9,10 @@ pub mod utils; pub mod cmd; pub mod hyperloglog; pub mod index; +pub mod manifest; pub mod minhash; pub mod nodegraph; +pub mod picklist; pub mod signature; pub mod storage; diff --git a/src/core/src/ffi/picklist.rs b/src/core/src/ffi/picklist.rs new file mode 100644 index 0000000000..5a2f9cc63b --- /dev/null +++ b/src/core/src/ffi/picklist.rs @@ -0,0 +1,76 @@ +use std::os::raw::c_char; +use std::slice; + +use crate::picklist::Picklist; + +use crate::ffi::utils::ForeignObject; + +pub struct SourmashPicklist; + +impl ForeignObject for SourmashPicklist { + type RustObject = Picklist; +} + +#[no_mangle] +pub unsafe extern "C" fn picklist_new() -> *mut SourmashPicklist { + SourmashPicklist::from_rust(Picklist::default()) +} + +#[no_mangle] +pub unsafe extern "C" fn picklist_free(ptr: *mut SourmashPicklist) { + SourmashPicklist::drop(ptr); +} + +ffi_fn! { +unsafe fn picklist_set_coltype( + ptr: *mut SourmashPicklist, + coltype_ptr: *const c_char, + insize: usize, +) -> Result<()> { + let coltype = { + assert!(!coltype_ptr.is_null()); + let coltype = slice::from_raw_parts(coltype_ptr as *mut u8, insize); + std::str::from_utf8(coltype)? + }; + let pl = SourmashPicklist::as_rust_mut(ptr); + pl.set_coltype(coltype.to_string()); + + Ok(()) +} +} + +ffi_fn! { +unsafe fn picklist_set_pickfile( + ptr: *mut SourmashPicklist, + prop_ptr: *const c_char, + insize: usize, +) -> Result<()> { + let prop = { + assert!(!prop_ptr.is_null()); + let prop = slice::from_raw_parts(prop_ptr as *mut u8, insize); + std::str::from_utf8(prop)? + }; + let pl = SourmashPicklist::as_rust_mut(ptr); + pl.set_pickfile(prop.to_string()); + + Ok(()) +} +} + +ffi_fn! { +unsafe fn picklist_set_column_name( + ptr: *mut SourmashPicklist, + prop_ptr: *const c_char, + insize: usize, +) -> Result<()> { + let prop = { + assert!(!prop_ptr.is_null()); + let prop = slice::from_raw_parts(prop_ptr as *mut u8, insize); + std::str::from_utf8(prop)? + }; + let pl = SourmashPicklist::as_rust_mut(ptr); + pl.set_column_name(prop.to_string()); + + Ok(()) +} +} diff --git a/src/core/src/ffi/storage.rs b/src/core/src/ffi/storage.rs index 98eca095b2..e8abcf1d51 100644 --- a/src/core/src/ffi/storage.rs +++ b/src/core/src/ffi/storage.rs @@ -1,5 +1,6 @@ use std::os::raw::c_char; use std::slice; +use std::sync::Arc; use crate::ffi::utils::{ForeignObject, SourmashStr}; use crate::prelude::*; @@ -8,7 +9,7 @@ use crate::storage::ZipStorage; pub struct SourmashZipStorage; impl ForeignObject for SourmashZipStorage { - type RustObject = ZipStorage; + type RustObject = Arc<ZipStorage>; } ffi_fn! { @@ -20,7 +21,7 @@ unsafe fn zipstorage_new(ptr: *const c_char, insize: usize) -> Result<*mut Sourm }; let zipstorage = ZipStorage::from_file(path)?; - Ok(SourmashZipStorage::from_rust(zipstorage)) + Ok(SourmashZipStorage::from_rust(Arc::new(zipstorage))) } } @@ -110,7 +111,7 @@ unsafe fn zipstorage_set_subdir( std::str::from_utf8(path)?
}; - storage.set_subdir(path.to_string()); + (*Arc::get_mut(storage).unwrap()).set_subdir(path.to_string()); Ok(()) } } diff --git a/src/core/src/index/linear.rs b/src/core/src/index/linear.rs index 78b2c6f1f5..20656a62e4 100644 --- a/src/core/src/index/linear.rs +++ b/src/core/src/index/linear.rs @@ -12,12 +12,12 @@ use crate::storage::{FSStorage, InnerStorage, Storage, StorageInfo}; use crate::Error; #[derive(TypedBuilder)] -pub struct LinearIndex { +pub struct LinearIndex { #[builder(default)] storage: Option, #[builder(default)] - datasets: Vec>, + datasets: Vec, } #[derive(Serialize, Deserialize)] @@ -27,15 +27,11 @@ struct LinearInfo { leaves: Vec, } -impl<'a, L> Index<'a> for LinearIndex -where - L: Clone + Comparable + 'a, - SigStore: From, -{ - type Item = L; +impl<'a> Index<'a> for LinearIndex { + type Item = Signature; //type SignatureIterator = std::slice::Iter<'a, Self::Item>; - fn insert(&mut self, node: L) -> Result<(), Error> { + fn insert(&mut self, node: Self::Item) -> Result<(), Error> { self.datasets.push(node.into()); Ok(()) } @@ -76,11 +72,7 @@ where */ } -impl LinearIndex -where - L: ToWriter, - SigStore: ReadData, -{ +impl LinearIndex { pub fn save_file>( &mut self, path: P, @@ -115,7 +107,7 @@ where .iter_mut() .map(|l| { // Trigger data loading - let _: &L = (*l).data().unwrap(); + let _: &Signature = (*l).data().unwrap(); // set storage to new one l.storage = Some(storage.clone()); @@ -137,7 +129,7 @@ where Ok(()) } - pub fn from_path>(path: P) -> Result, Error> { + pub fn from_path>(path: P) -> Result { let file = File::open(&path)?; let mut reader = BufReader::new(file); @@ -147,11 +139,11 @@ where basepath.push(path); basepath.canonicalize()?; - let linear = LinearIndex::::from_reader(&mut reader, basepath.parent().unwrap())?; + let linear = LinearIndex::from_reader(&mut reader, basepath.parent().unwrap())?; Ok(linear) } - pub fn from_reader(rdr: R, path: P) -> Result, Error> + pub fn from_reader(rdr: R, path: P) -> Result where R: Read, P: AsRef, @@ -171,7 +163,7 @@ where .leaves .into_iter() .map(|l| { - let mut v: SigStore = l.into(); + let mut v: SigStore = l.into(); v.storage = Some(storage.clone()); v }) diff --git a/src/core/src/index/mod.rs b/src/core/src/index/mod.rs index 832fdf9091..6d5d87103b 100644 --- a/src/core/src/index/mod.rs +++ b/src/core/src/index/mod.rs @@ -15,6 +15,7 @@ use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; +use crate::encodings::HashFunctions; use crate::errors::ReadDataError; use crate::index::search::{search_minhashes, search_minhashes_containment}; use crate::prelude::*; @@ -23,6 +24,39 @@ use crate::sketch::Sketch; use crate::storage::{InnerStorage, Storage}; use crate::Error; +#[derive(Default)] +pub struct Selection { + ksize: Option, + abund: Option, + moltype: Option, +} + +impl Selection { + pub fn ksize(&self) -> Option { + self.ksize + } + + pub fn set_ksize(&mut self, ksize: u32) { + self.ksize = Some(ksize); + } + + pub fn abund(&self) -> Option { + self.abund + } + + pub fn set_abund(&mut self, value: bool) { + self.abund = Some(value); + } + + pub fn moltype(&self) -> Option { + self.moltype + } + + pub fn set_moltype(&mut self, value: HashFunctions) { + self.moltype = Some(value); + } +} + pub trait Index<'a> { type Item: Comparable; //type SignatureIterator: Iterator; @@ -116,7 +150,7 @@ pub struct DatasetInfo { } #[derive(TypedBuilder, Default, Clone)] -pub struct SigStore { +pub struct SigStore { #[builder(setter(into))] filename: String, @@ -129,16 +163,16 
@@ pub struct SigStore { storage: Option, #[builder(setter(into), default)] - data: OnceCell, + data: OnceCell, } -impl SigStore { +impl SigStore { pub fn name(&self) -> String { self.name.clone() } } -impl std::fmt::Debug for SigStore { +impl std::fmt::Debug for SigStore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, @@ -148,7 +182,7 @@ impl std::fmt::Debug for SigStore { } } -impl ReadData for SigStore { +impl ReadData for SigStore { fn data(&self) -> Result<&Signature, Error> { if let Some(sig) = self.data.get() { Ok(sig) @@ -172,10 +206,7 @@ impl ReadData for SigStore { } } -impl SigStore -where - T: ToWriter, -{ +impl SigStore { pub fn save(&self, path: &str) -> Result { if let Some(storage) = &self.storage { if let Some(data) = self.data.get() { @@ -192,8 +223,8 @@ where } } -impl SigStore { - pub fn count_common(&self, other: &SigStore) -> u64 { +impl SigStore { + pub fn count_common(&self, other: &SigStore) -> u64 { let ng: &Signature = self.data().unwrap(); let ong: &Signature = other.data().unwrap(); @@ -220,13 +251,13 @@ impl SigStore { } } -impl From> for Signature { - fn from(other: SigStore) -> Signature { +impl From for Signature { + fn from(other: SigStore) -> Signature { other.data.get().unwrap().to_owned() } } -impl Deref for SigStore { +impl Deref for SigStore { type Target = Signature; fn deref(&self) -> &Signature { @@ -234,8 +265,8 @@ impl Deref for SigStore { } } -impl From for SigStore { - fn from(other: Signature) -> SigStore { +impl From for SigStore { + fn from(other: Signature) -> SigStore { let name = other.name(); let filename = other.filename(); @@ -249,8 +280,8 @@ impl From for SigStore { } } -impl Comparable> for SigStore { - fn similarity(&self, other: &SigStore) -> f64 { +impl Comparable for SigStore { + fn similarity(&self, other: &SigStore) -> f64 { let ng: &Signature = self.data().unwrap(); let ong: &Signature = other.data().unwrap(); @@ -273,7 +304,7 @@ impl Comparable> for SigStore { unimplemented!() } - fn containment(&self, other: &SigStore) -> f64 { + fn containment(&self, other: &SigStore) -> f64 { let ng: &Signature = self.data().unwrap(); let ong: &Signature = other.data().unwrap(); @@ -325,8 +356,8 @@ impl Comparable for Signature { } } -impl From for SigStore { - fn from(other: DatasetInfo) -> SigStore { +impl From for SigStore { + fn from(other: DatasetInfo) -> SigStore { SigStore { filename: other.filename, name: other.name, diff --git a/src/core/src/index/revindex/mem_revindex.rs b/src/core/src/index/revindex/mem_revindex.rs index 69bbc4de7d..1d089fcefe 100644 --- a/src/core/src/index/revindex/mem_revindex.rs +++ b/src/core/src/index/revindex/mem_revindex.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use getset::{CopyGetters, Getters, Setters}; use log::{debug, info}; @@ -12,10 +13,12 @@ use typed_builder::TypedBuilder; use rayon::prelude::*; use crate::encodings::{Color, Colors, Idx}; -use crate::index::Index; +use crate::index::{Index, Selection, SigStore}; +use crate::manifest::Manifest; use crate::signature::{Signature, SigsTrait}; use crate::sketch::minhash::{KmerMinHash, MinHashOps}; use crate::sketch::Sketch; +use crate::storage::{Storage, ZipStorage}; use crate::Error; use crate::HashIntoType; @@ -100,17 +103,524 @@ impl HashToColor { // https://davidkoloski.me/rkyv/ #[derive(Serialize, Deserialize)] pub struct RevIndex { + linear: LinearRevIndex, hash_to_color: HashToColor, + 
colors: Colors, +} - sig_files: Vec<PathBuf>, +#[derive(Serialize, Deserialize)] +pub struct LinearRevIndex { + sig_files: Manifest, #[serde(skip)] - ref_sigs: Option<Vec<Signature>>, + ref_sigs: Option<Vec<SigStore>>, template: Sketch, - colors: Colors, - //#[serde(skip)] - //storage: Option<InnerStorage>, + + #[serde(skip)] + storage: Option<Arc<ZipStorage>>, +} + +impl LinearRevIndex { + pub fn new( + sig_files: Option<Manifest>, + template: &Sketch, + keep_sigs: bool, + ref_sigs: Option<Vec<Signature>>, + storage: Option<ZipStorage>, + ) -> Self { + if ref_sigs.is_none() && sig_files.is_none() { + todo!("throw error, one need to be set"); + } + + let ref_sigs = if let Some(ref_sigs) = ref_sigs { + Some(ref_sigs.into_iter().map(|m| m.into()).collect()) + } else if keep_sigs { + let search_sigs: Vec<_> = sig_files + .as_ref() + .unwrap() + .internal_locations() + .map(PathBuf::from) + .collect(); + + #[cfg(feature = "parallel")] + let sigs_iter = search_sigs.par_iter(); + + #[cfg(not(feature = "parallel"))] + let sigs_iter = search_sigs.iter(); + + Some( + sigs_iter + .map(|ref_path| { + if let Some(storage) = &storage { + let sig_data = storage + .load(ref_path.to_str().unwrap_or_else(|| { + panic!("error converting path {:?}", ref_path) + })) + .unwrap_or_else(|_| panic!("error loading {:?}", ref_path)); + Signature::from_reader(sig_data.as_slice()) + .unwrap_or_else(|_| panic!("Error processing {:?}", ref_path)) + .swap_remove(0) + .into() + } else { + Signature::from_path(&ref_path) + .unwrap_or_else(|_| panic!("Error processing {:?}", ref_path)) + .swap_remove(0) + .into() + } + }) + .collect(), + ) + } else { + None + }; + + let storage = storage.map(Arc::new); + + LinearRevIndex { + sig_files: sig_files.unwrap(), + template: template.clone(), + ref_sigs, + storage, + } + } + + fn index( + self, + threshold: usize, + merged_query: Option<KmerMinHash>, + queries: Option<&[KmerMinHash]>, + ) -> RevIndex { + let processed_sigs = AtomicUsize::new(0); + + let search_sigs: Vec<_> = self + .sig_files + .internal_locations() + .map(PathBuf::from) + .collect(); + + #[cfg(feature = "parallel")] + let sig_iter = search_sigs.par_iter(); + + #[cfg(not(feature = "parallel"))] + let sig_iter = search_sigs.iter(); + + let filtered_sigs = sig_iter.enumerate().filter_map(|(dataset_id, filename)| { + let i = processed_sigs.fetch_add(1, Ordering::SeqCst); + if i % 1000 == 0 { + info!("Processed {} reference sigs", i); + } + + let search_sig = if let Some(storage) = &self.storage { + let sig_data = storage + .load( + filename + .to_str() + .unwrap_or_else(|| panic!("error converting path {:?}", filename)), + ) + .unwrap_or_else(|_| panic!("error loading {:?}", filename)); + + Signature::from_reader(sig_data.as_slice()) + } else { + Signature::from_path(&filename) + } + .unwrap_or_else(|_| panic!("Error processing {:?}", filename)) + .swap_remove(0); + + RevIndex::map_hashes_colors( + dataset_id, + &search_sig, + queries, + &merged_query, + threshold, + &self.template, + ) + }); + + #[cfg(feature = "parallel")] + let (hash_to_color, colors) = filtered_sigs.reduce( + || (HashToColor::new(), Colors::default()), + HashToColor::reduce_hashes_colors, + ); + + #[cfg(not(feature = "parallel"))] + let (hash_to_color, colors) = filtered_sigs.fold( + (HashToColor::new(), Colors::default()), + HashToColor::reduce_hashes_colors, + ); + + RevIndex { + hash_to_color, + colors, + linear: self, + } + } + + pub fn location(&self) -> Option<String> { + if let Some(storage) = &self.storage { + storage.path() + } else { + None + } + } + + pub fn storage(&self) -> Option<Arc<ZipStorage>> { + self.storage.clone() + } + + pub fn select(mut self, selection:
&Selection) -> Result { + let manifest = self.sig_files.select_to_manifest(selection)?; + self.sig_files = manifest; + + Ok(self) + /* + # if we have a manifest, run 'select' on the manifest. + manifest = self.manifest + traverse_yield_all = self.traverse_yield_all + + if manifest is not None: + manifest = manifest.select_to_manifest(**kwargs) + return ZipFileLinearIndex(self.storage, + selection_dict=None, + traverse_yield_all=traverse_yield_all, + manifest=manifest, + use_manifest=True) + else: + # no manifest? just pass along all the selection kwargs to + # the new ZipFileLinearIndex. + + assert manifest is None + if self.selection_dict: + # combine selects... + d = dict(self.selection_dict) + for k, v in kwargs.items(): + if k in d: + if d[k] is not None and d[k] != v: + raise ValueError(f"incompatible select on '{k}'") + d[k] = v + kwargs = d + + return ZipFileLinearIndex(self.storage, + selection_dict=kwargs, + traverse_yield_all=traverse_yield_all, + manifest=None, + use_manifest=False) + */ + } + + pub fn counter_for_query(&self, query: &KmerMinHash) -> SigCounter { + let processed_sigs = AtomicUsize::new(0); + + // TODO: Some(ref_sigs) case + + let search_sigs: Vec<_> = self + .sig_files + .internal_locations() + .map(PathBuf::from) + .collect(); + + #[cfg(feature = "parallel")] + let sig_iter = search_sigs.par_iter(); + + #[cfg(not(feature = "parallel"))] + let sig_iter = search_sigs.iter(); + + let counters = sig_iter.enumerate().filter_map(|(dataset_id, filename)| { + let i = processed_sigs.fetch_add(1, Ordering::SeqCst); + if i % 1000 == 0 { + info!("Processed {} reference sigs", i); + } + + let search_sig = if let Some(storage) = &self.storage { + let sig_data = storage + .load( + filename + .to_str() + .unwrap_or_else(|| panic!("error converting path {:?}", filename)), + ) + .unwrap_or_else(|_| panic!("error loading {:?}", filename)); + + Signature::from_reader(sig_data.as_slice()) + } else { + Signature::from_path(&filename) + } + .unwrap_or_else(|_| panic!("Error processing {:?}", filename)) + .swap_remove(0); + + let mut search_mh = None; + if let Some(Sketch::MinHash(mh)) = search_sig.select_sketch(&self.template) { + search_mh = Some(mh); + }; + let search_mh = search_mh.expect("Couldn't find a compatible MinHash"); + + let (large_mh, small_mh) = if query.size() > search_mh.size() { + (query, search_mh) + } else { + (search_mh, query) + }; + + let (size, _) = small_mh + .intersection_size(large_mh) + .unwrap_or_else(|_| panic!("error computing intersection for {:?}", filename)); + + if size == 0 { + None + } else { + let mut counter: SigCounter = Default::default(); + counter[&(dataset_id as u64)] += size as usize; + Some(counter) + } + }); + + let reduce_counters = |mut a: SigCounter, b: SigCounter| { + a.extend(&b); + a + }; + + #[cfg(feature = "parallel")] + let counter = counters.reduce(|| SigCounter::new(), reduce_counters); + + #[cfg(not(feature = "parallel"))] + let counter = counters.fold(SigCounter::new(), reduce_counters); + + counter + } + + pub fn search( + &self, + counter: SigCounter, + similarity: bool, + threshold: usize, + ) -> Result, Box> { + let mut matches = vec![]; + if similarity { + unimplemented!("TODO: threshold correction") + } + + for (dataset_id, size) in counter.most_common() { + if size >= threshold { + matches.push( + self.sig_files[dataset_id as usize] + .internal_location() + .to_str() + .unwrap() + .into(), + ); + } else { + break; + }; + } + Ok(matches) + } + + fn gather_round( + &self, + dataset_id: u64, + match_size: usize, + query: 
&KmerMinHash, + round: usize, + ) -> Result { + let match_path = if self.sig_files.is_empty() { + PathBuf::new() + } else { + self.sig_files[dataset_id as usize].internal_location() + }; + let match_sig = self.sig_for_dataset(dataset_id as usize)?; + let result = self.stats_for_match(&match_sig, query, match_size, match_path, round)?; + Ok(result) + } + + fn sig_for_dataset(&self, dataset_id: usize) -> Result { + let match_path = if self.sig_files.is_empty() { + PathBuf::new() + } else { + self.sig_files[dataset_id as usize].internal_location() + }; + + let match_sig = if let Some(refsigs) = &self.ref_sigs { + refsigs[dataset_id as usize].clone() + } else { + let mut sig = if let Some(storage) = &self.storage { + let sig_data = storage + .load( + match_path + .to_str() + .unwrap_or_else(|| panic!("error converting path {:?}", match_path)), + ) + .unwrap_or_else(|_| panic!("error loading {:?}", match_path)); + Signature::from_reader(sig_data.as_slice())? + } else { + Signature::from_path(&match_path)? + }; + // TODO: remove swap_remove + sig.swap_remove(0).into() + }; + Ok(match_sig) + } + + fn stats_for_match( + &self, + match_sig: &Signature, + query: &KmerMinHash, + match_size: usize, + match_path: PathBuf, + gather_result_rank: usize, + ) -> Result { + let mut match_mh = None; + if let Some(Sketch::MinHash(mh)) = match_sig.select_sketch(&self.template) { + match_mh = Some(mh); + } + let match_mh = match_mh.expect("Couldn't find a compatible MinHash"); + + // Calculate stats + let f_orig_query = match_size as f64 / query.size() as f64; + let f_match = match_size as f64 / match_mh.size() as f64; + let filename = match_path.to_str().unwrap().into(); + let name = match_sig.name(); + let unique_intersect_bp = match_mh.scaled() as usize * match_size; + + let (intersect_orig, _) = match_mh.intersection_size(query)?; + let intersect_bp = (match_mh.scaled() as u64 * intersect_orig) as usize; + + let f_unique_to_query = intersect_orig as f64 / query.size() as f64; + let match_ = match_sig.clone(); + + // TODO: all of these + let f_unique_weighted = 0.; + let average_abund = 0; + let median_abund = 0; + let std_abund = 0; + let md5 = "".into(); + let f_match_orig = 0.; + let remaining_bp = 0; + + Ok(GatherResult { + intersect_bp, + f_orig_query, + f_match, + f_unique_to_query, + f_unique_weighted, + average_abund, + median_abund, + std_abund, + filename, + name, + md5, + match_, + f_match_orig, + unique_intersect_bp, + gather_result_rank, + remaining_bp, + }) + } + + pub fn gather( + &self, + mut counter: SigCounter, + threshold: usize, + query: &KmerMinHash, + ) -> Result, Box> { + let mut match_size = usize::max_value(); + let mut matches = vec![]; + + while match_size > threshold && !counter.is_empty() { + let (dataset_id, size) = counter.most_common()[0]; + if threshold == 0 && size == 0 { + break; + } + + match_size = if size >= threshold { + size + } else { + break; + }; + + let result = self.gather_round(dataset_id, match_size, query, matches.len())?; + + // Prepare counter for finding the next match by decrementing + // all hashes found in the current match in other datasets + // TODO: maybe par_iter? 
+ let mut to_remove: HashSet = Default::default(); + to_remove.insert(dataset_id); + + for (dataset, value) in counter.iter_mut() { + let dataset_sig = self.sig_for_dataset(*dataset as usize)?; + let mut match_mh = None; + if let Some(Sketch::MinHash(mh)) = dataset_sig.select_sketch(&self.template) { + match_mh = Some(mh); + } + let match_mh = match_mh.expect("Couldn't find a compatible MinHash"); + + let (intersection, _) = query.intersection_size(match_mh)?; + if intersection as usize > *value { + to_remove.insert(*dataset); + } else { + *value -= intersection as usize; + }; + } + to_remove.iter().for_each(|dataset_id| { + counter.remove(dataset_id); + }); + matches.push(result); + } + Ok(matches) + } + + pub fn manifest(&self) -> Manifest { + self.sig_files.clone() + } + + pub fn set_manifest(&mut self, new_manifest: Manifest) -> Result<(), Error> { + self.sig_files = new_manifest; + Ok(()) + } + + pub fn signatures_iter(&self) -> impl Iterator + '_ { + if let Some(_sigs) = &self.ref_sigs { + //sigs.iter().cloned() + todo!("this works, but need to match return types") + } else { + // FIXME temp solution, must find better one! + (0..self.sig_files.len()) + .map(move |dataset_id| self.sig_for_dataset(dataset_id).expect("error loading sig")) + } + } +} + +impl<'a> Index<'a> for LinearRevIndex { + type Item = SigStore; + + fn insert(&mut self, _node: Self::Item) -> Result<(), Error> { + unimplemented!() + } + + fn save>(&self, _path: P) -> Result<(), Error> { + unimplemented!() + } + + fn load>(_path: P) -> Result<(), Error> { + unimplemented!() + } + + fn len(&self) -> usize { + if let Some(refs) = &self.ref_sigs { + refs.len() + } else { + self.sig_files.len() + } + } + + fn signatures(&self) -> Vec { + if let Some(ref sigs) = self.ref_sigs { + sigs.to_vec() + } else { + unimplemented!() + } + } + + fn signature_refs(&self) -> Vec<&Self::Item> { + unimplemented!() + } } impl RevIndex { @@ -165,75 +675,33 @@ impl RevIndex { // If threshold is zero, let's merge all queries and save time later let merged_query = queries.and_then(|qs| Self::merge_queries(qs, threshold)); - let processed_sigs = AtomicUsize::new(0); - - #[cfg(feature = "parallel")] - let sig_iter = search_sigs.par_iter(); - - #[cfg(not(feature = "parallel"))] - let sig_iter = search_sigs.iter(); - - let filtered_sigs = sig_iter.enumerate().filter_map(|(dataset_id, filename)| { - let i = processed_sigs.fetch_add(1, Ordering::SeqCst); - if i % 1000 == 0 { - info!("Processed {} reference sigs", i); - } - - let search_sig = Signature::from_path(filename) - .unwrap_or_else(|_| panic!("Error processing {:?}", filename)) - .swap_remove(0); + let linear = LinearRevIndex::new(Some(search_sigs.into()), template, keep_sigs, None, None); + linear.index(threshold, merged_query, queries) + } - RevIndex::map_hashes_colors( - dataset_id, - &search_sig, - queries, - &merged_query, - threshold, - template, - ) - }); + pub fn from_zipstorage( + storage: ZipStorage, + template: &Sketch, + threshold: usize, + queries: Option<&[KmerMinHash]>, + keep_sigs: bool, + ) -> Result { + // If threshold is zero, let's merge all queries and save time later + let merged_query = queries.and_then(|qs| Self::merge_queries(qs, threshold)); - #[cfg(feature = "parallel")] - let (hash_to_color, colors) = filtered_sigs.reduce( - || (HashToColor::new(), Colors::default()), - HashToColor::reduce_hashes_colors, - ); + // Load manifest from zipstorage + let manifest = Manifest::from_reader(storage.load("SOURMASH-MANIFEST.csv")?.as_slice())?; + let search_sigs: Vec<_> = 
manifest.internal_locations().map(PathBuf::from).collect(); - #[cfg(not(feature = "parallel"))] - let (hash_to_color, colors) = filtered_sigs.fold( - (HashToColor::new(), Colors::default()), - HashToColor::reduce_hashes_colors, + let linear = LinearRevIndex::new( + Some(search_sigs.as_slice().into()), + template, + keep_sigs, + None, + Some(storage), ); - // TODO: build this together with hash_to_idx? - let ref_sigs = if keep_sigs { - #[cfg(feature = "parallel")] - let sigs_iter = search_sigs.par_iter(); - - #[cfg(not(feature = "parallel"))] - let sigs_iter = search_sigs.iter(); - - Some( - sigs_iter - .map(|ref_path| { - Signature::from_path(ref_path) - .unwrap_or_else(|_| panic!("Error processing {:?}", ref_path)) - .swap_remove(0) - }) - .collect(), - ) - } else { - None - }; - - RevIndex { - hash_to_color, - sig_files: search_sigs.into(), - ref_sigs, - template: template.clone(), - colors, - // storage: Some(InnerStorage::new(MemStorage::default())), - } + Ok(linear.index(threshold, merged_query, queries)) } fn merge_queries(qs: &[KmerMinHash], threshold: usize) -> Option { @@ -257,49 +725,15 @@ impl RevIndex { // If threshold is zero, let's merge all queries and save time later let merged_query = queries.and_then(|qs| Self::merge_queries(qs, threshold)); - let processed_sigs = AtomicUsize::new(0); - - #[cfg(feature = "parallel")] - let sigs_iter = search_sigs.par_iter(); - #[cfg(not(feature = "parallel"))] - let sigs_iter = search_sigs.iter(); - - let filtered_sigs = sigs_iter.enumerate().filter_map(|(dataset_id, sig)| { - let i = processed_sigs.fetch_add(1, Ordering::SeqCst); - if i % 1000 == 0 { - info!("Processed {} reference sigs", i); - } - - RevIndex::map_hashes_colors( - dataset_id, - sig, - queries, - &merged_query, - threshold, - template, - ) - }); - - #[cfg(feature = "parallel")] - let (hash_to_color, colors) = filtered_sigs.reduce( - || (HashToColor::new(), Colors::default()), - HashToColor::reduce_hashes_colors, + let linear = LinearRevIndex::new( + Default::default(), + template, + false, + search_sigs.into(), + None, ); - #[cfg(not(feature = "parallel"))] - let (hash_to_color, colors) = filtered_sigs.fold( - (HashToColor::new(), Colors::default()), - HashToColor::reduce_hashes_colors, - ); - - RevIndex { - hash_to_color, - sig_files: vec![], - ref_sigs: search_sigs.into(), - template: template.clone(), - colors, - //storage: None, - } + linear.index(threshold, merged_query, queries) } fn map_hashes_colors( @@ -348,25 +782,22 @@ impl RevIndex { } } + pub fn counter_for_query(&self, query: &KmerMinHash) -> SigCounter { + query + .iter_mins() + .filter_map(|hash| self.hash_to_color.get(hash)) + .flat_map(|color| self.colors.indices(color)) + .cloned() + .collect() + } + pub fn search( &self, counter: SigCounter, similarity: bool, threshold: usize, ) -> Result, Box> { - let mut matches = vec![]; - if similarity { - unimplemented!("TODO: threshold correction") - } - - for (dataset_id, size) in counter.most_common() { - if size >= threshold { - matches.push(self.sig_files[dataset_id as usize].to_str().unwrap().into()); - } else { - break; - }; - } - Ok(matches) + self.linear.search(counter, similarity, threshold) } pub fn gather( @@ -381,102 +812,30 @@ impl RevIndex { while match_size > threshold && !counter.is_empty() { let (dataset_id, size) = counter.most_common()[0]; match_size = if size >= threshold { size } else { break }; - - let p; - let match_path = if self.sig_files.is_empty() { - p = PathBuf::new(); // TODO: Fix somehow? 
- &p - } else { - &self.sig_files[dataset_id as usize] - }; - - let ref_match; - let match_sig = if let Some(refsigs) = &self.ref_sigs { - &refsigs[dataset_id as usize] - } else { - // TODO: remove swap_remove - ref_match = Signature::from_path(match_path)?.swap_remove(0); - &ref_match - }; - - let mut match_mh = None; - if let Some(Sketch::MinHash(mh)) = match_sig.select_sketch(&self.template) { - match_mh = Some(mh); - } - let match_mh = match_mh.expect("Couldn't find a compatible MinHash"); - - // Calculate stats - let f_orig_query = match_size as f64 / query.size() as f64; - let f_match = match_size as f64 / match_mh.size() as f64; - let filename = match_path.to_str().unwrap().into(); - let name = match_sig.name(); - let unique_intersect_bp = match_mh.scaled() as usize * match_size; - let gather_result_rank = matches.len(); - - let (intersect_orig, _) = match_mh.intersection_size(query)?; - let intersect_bp = (match_mh.scaled() * intersect_orig) as usize; - - let f_unique_to_query = intersect_orig as f64 / query.size() as f64; - let match_ = match_sig.clone(); - - // TODO: all of these - let f_unique_weighted = 0.; - let average_abund = 0; - let median_abund = 0; - let std_abund = 0; - let md5 = "".into(); - let f_match_orig = 0.; - let remaining_bp = 0; - - let result = GatherResult { - intersect_bp, - f_orig_query, - f_match, - f_unique_to_query, - f_unique_weighted, - average_abund, - median_abund, - std_abund, - filename, - name, - md5, - match_, - f_match_orig, - unique_intersect_bp, - gather_result_rank, - remaining_bp, - }; - matches.push(result); - - // Prepare counter for finding the next match by decrementing - // all hashes found in the current match in other datasets - for hash in match_mh.iter_mins() { - if let Some(color) = self.hash_to_color.get(hash) { - for dataset in self.colors.indices(color) { - counter.entry(*dataset).and_modify(|e| { - if *e > 0 { - *e -= 1 - } - }); + let result = self + .linear + .gather_round(dataset_id, match_size, query, matches.len())?; + if let Some(Sketch::MinHash(match_mh)) = + result.match_.select_sketch(&self.linear.template) + { + // Prepare counter for finding the next match by decrementing + // all hashes found in the current match in other datasets + for hash in match_mh.iter_mins() { + if let Some(color) = self.hash_to_color.get(hash) { + counter.subtract(self.colors.indices(color).cloned()); } } + counter.remove(&dataset_id); + matches.push(result); + } else { + unimplemented!() } - counter.remove(&dataset_id); } Ok(matches) } - pub fn counter_for_query(&self, query: &KmerMinHash) -> SigCounter { - query - .iter_mins() - .filter_map(|hash| self.hash_to_color.get(hash)) - .flat_map(|color| self.colors.indices(color)) - .cloned() - .collect() - } - pub fn template(&self) -> Sketch { - self.template.clone() + self.linear.template.clone() } // TODO: mh should be a sketch, or even a sig... @@ -527,25 +886,34 @@ impl RevIndex { for (dataset_id, size) in counter.most_common() { let match_size = if size >= threshold { size } else { break }; - let p; - let match_path = if self.sig_files.is_empty() { - p = PathBuf::new(); // TODO: Fix somehow? 
- &p + let match_path = if self.linear.sig_files.is_empty() { + PathBuf::new() } else { - &self.sig_files[dataset_id as usize] + self.linear.sig_files[dataset_id as usize].internal_location() }; let ref_match; - let match_sig = if let Some(refsigs) = &self.ref_sigs { + let match_sig = if let Some(refsigs) = &self.linear.ref_sigs { &refsigs[dataset_id as usize] } else { + let mut sig = if let Some(storage) = &self.linear.storage { + let sig_data = + storage + .load(match_path.to_str().unwrap_or_else(|| { + panic!("error converting path {:?}", match_path) + })) + .unwrap_or_else(|_| panic!("error loading {:?}", match_path)); + Signature::from_reader(sig_data.as_slice())? + } else { + Signature::from_path(&match_path)? + }; // TODO: remove swap_remove - ref_match = Signature::from_path(match_path)?.swap_remove(0); + ref_match = sig.swap_remove(0); &ref_match }; let mut match_mh = None; - if let Some(Sketch::MinHash(mh)) = match_sig.select_sketch(&self.template) { + if let Some(Sketch::MinHash(mh)) = match_sig.select_sketch(&self.linear.template) { match_mh = Some(mh); } let match_mh = match_mh.unwrap(); @@ -569,7 +937,7 @@ impl RevIndex { } } -#[derive(TypedBuilder, CopyGetters, Getters, Setters, Serialize, Deserialize, Debug)] +#[derive(TypedBuilder, CopyGetters, Getters, Setters, Serialize, Deserialize, Debug, PartialEq)] pub struct GatherResult { #[getset(get_copy = "pub")] intersect_bp: usize, @@ -594,7 +962,6 @@ pub struct GatherResult { #[getset(get = "pub")] md5: String, - match_: Signature, f_match_orig: f64, unique_intersect_bp: usize, @@ -624,16 +991,16 @@ impl<'a> Index<'a> for RevIndex { } fn len(&self) -> usize { - if let Some(refs) = &self.ref_sigs { + if let Some(refs) = &self.linear.ref_sigs { refs.len() } else { - self.sig_files.len() + self.linear.sig_files.len() } } fn signatures(&self) -> Vec { - if let Some(ref sigs) = self.ref_sigs { - sigs.to_vec() + if let Some(ref sigs) = self.linear.ref_sigs { + sigs.iter().map(|s| s.clone().into()).collect() } else { unimplemented!() } @@ -699,4 +1066,49 @@ mod test { //assert_eq!(index.colors.len(), 3); assert_eq!(index.colors.len(), 7); } + + #[test] + fn revindex_from_zipstorage() { + let max_hash = max_hash_for_scaled(100); + let template = Sketch::MinHash( + KmerMinHash::builder() + .num(0u32) + .ksize(57) + .hash_function(crate::encodings::HashFunctions::murmur64_protein) + .max_hash(max_hash) + .build(), + ); + let storage = ZipStorage::from_file("../../tests/test-data/prot/protein.zip") + .expect("error loading zipfile"); + let index = RevIndex::from_zipstorage(storage, &template, 0, None, false) + .expect("error building from ziptorage"); + + assert_eq!(index.colors.len(), 3); + + let query_sig = Signature::from_path( + "../../tests/test-data/prot/protein/GCA_001593925.1_ASM159392v1_protein.faa.gz.sig", + ) + .expect("Error processing query") + .swap_remove(0); + let mut query_mh = None; + if let Some(Sketch::MinHash(mh)) = query_sig.select_sketch(&template) { + query_mh = Some(mh); + } + let query_mh = query_mh.expect("Couldn't find a compatible MinHash"); + + let counter_rev = index.counter_for_query(query_mh); + let counter_lin = index.linear.counter_for_query(query_mh); + + let results_rev = index.search(counter_rev, false, 0).unwrap(); + let results_linear = index.linear.search(counter_lin, false, 0).unwrap(); + assert_eq!(results_rev, results_linear); + + let counter_rev = index.counter_for_query(query_mh); + let counter_lin = index.linear.counter_for_query(query_mh); + + let results_rev = index.gather(counter_rev, 0, 
query_mh).unwrap(); + let results_linear = index.linear.gather(counter_lin, 0, query_mh).unwrap(); + assert_eq!(results_rev.len(), 1); + assert_eq!(results_rev, results_linear); + } } diff --git a/src/core/src/lib.rs b/src/core/src/lib.rs index 14a25ab632..a52750535a 100644 --- a/src/core/src/lib.rs +++ b/src/core/src/lib.rs @@ -26,6 +26,8 @@ pub mod prelude; pub mod cmd; +pub mod manifest; +pub mod picklist; pub mod signature; pub mod sketch; pub mod storage; diff --git a/src/core/src/manifest.rs b/src/core/src/manifest.rs new file mode 100644 index 0000000000..ce740c638b --- /dev/null +++ b/src/core/src/manifest.rs @@ -0,0 +1,186 @@ +use std::convert::TryInto; +use std::io::Read; +use std::ops::Deref; +use std::path::PathBuf; + +use serde::de; +use serde::{Deserialize, Serialize}; + +use crate::encodings::HashFunctions; +use crate::index::Selection; +use crate::Error; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Record { + internal_location: String, + ksize: u32, + + #[serde(deserialize_with = "to_bool")] + with_abundance: bool, + + md5: String, + name: String, + moltype: String, + /* + md5short: String, + num: String, + scaled: String, + n_hashes: String, + filename: String, + */ +} + +fn to_bool<'de, D>(deserializer: D) -> Result +where + D: de::Deserializer<'de>, +{ + match String::deserialize(deserializer)? + .to_ascii_lowercase() + .as_ref() + { + "0" | "false" => Ok(false), + "1" | "true" => Ok(true), + other => Err(de::Error::invalid_value( + de::Unexpected::Str(other), + &"0/1 or true/false are the only supported values", + )), + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct Manifest { + records: Vec, +} + +impl Record { + pub fn internal_location(&self) -> PathBuf { + self.internal_location.clone().into() + } + + pub fn ksize(&self) -> u32 { + self.ksize + } + + pub fn with_abundance(&self) -> bool { + self.with_abundance + } + + pub fn md5(&self) -> &str { + self.md5.as_ref() + } + + pub fn name(&self) -> &str { + self.name.as_ref() + } + + pub fn moltype(&self) -> HashFunctions { + self.moltype.as_str().try_into().unwrap() + } +} + +impl Manifest { + pub fn from_reader(rdr: R) -> Result { + let mut records = vec![]; + + let mut rdr = csv::ReaderBuilder::new() + .comment(Some(b'#')) + .from_reader(rdr); + for result in rdr.deserialize() { + let record: Record = result?; + records.push(record); + } + Ok(Manifest { records }) + } + + pub fn internal_locations(&self) -> impl Iterator { + self.records.iter().map(|r| r.internal_location.as_str()) + } + + pub fn iter(&self) -> impl Iterator { + self.records.iter() + } + + pub fn select_to_manifest(&self, selection: &Selection) -> Result { + let rows = self.records.iter().filter(|row| { + let mut valid = true; + valid = if let Some(ksize) = selection.ksize() { + row.ksize == ksize + } else { + valid + }; + valid = if let Some(abund) = selection.abund() { + valid && row.with_abundance() == abund + } else { + valid + }; + valid = if let Some(moltype) = selection.moltype() { + valid && row.moltype() == moltype + } else { + valid + }; + valid + }); + + Ok(Manifest { + records: rows.cloned().collect(), + }) + + /* + matching_rows = self.rows + if ksize: + matching_rows = ( row for row in matching_rows + if row['ksize'] == ksize ) + if moltype: + matching_rows = ( row for row in matching_rows + if row['moltype'] == moltype ) + if scaled or containment: + if containment and not scaled: + raise ValueError("'containment' requires 'scaled' in Index.select'") + + matching_rows = ( row for row in 
matching_rows + if row['scaled'] and not row['num'] ) + if num: + matching_rows = ( row for row in matching_rows + if row['num'] and not row['scaled'] ) + + if abund: + # only need to concern ourselves if abundance is _required_ + matching_rows = ( row for row in matching_rows + if row['with_abundance'] ) + + if picklist: + matching_rows = ( row for row in matching_rows + if picklist.matches_manifest_row(row) ) + + # return only the internal filenames! + for row in matching_rows: + yield row + */ + } +} + +impl From<&[PathBuf]> for Manifest { + fn from(v: &[PathBuf]) -> Self { + Manifest { + records: v + .iter() + .map(|p| Record { + internal_location: p.to_str().unwrap().into(), + ksize: 0, // FIXME + with_abundance: false, // FIXME + md5: "".into(), // FIXME + name: "".into(), // FIXME + moltype: "".into(), // FIXME + }) + .collect(), + } + } +} + +impl Deref for Manifest { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.records + } +} diff --git a/src/core/src/picklist.rs b/src/core/src/picklist.rs new file mode 100644 index 0000000000..0b77cb4457 --- /dev/null +++ b/src/core/src/picklist.rs @@ -0,0 +1,17 @@ +use getset::{CopyGetters, Getters, Setters}; +use typed_builder::TypedBuilder; + +#[derive(Default, TypedBuilder, CopyGetters, Getters, Setters)] +pub struct Picklist { + #[getset(get = "pub", set = "pub")] + #[builder(default = "".into())] + coltype: String, + + #[getset(get = "pub", set = "pub")] + #[builder(default = "".into())] + pickfile: String, + + #[getset(get = "pub", set = "pub")] + #[builder(default = "".into())] + column_name: String, +} diff --git a/src/core/src/storage.rs b/src/core/src/storage.rs index a2269bd6e7..51738e4ccc 100644 --- a/src/core/src/storage.rs +++ b/src/core/src/storage.rs @@ -24,10 +24,10 @@ pub trait Storage { } #[derive(Clone)] -pub struct InnerStorage(Arc>); +pub struct InnerStorage(Arc>); impl InnerStorage { - pub fn new(inner: impl Storage + 'static) -> InnerStorage { + pub fn new(inner: impl Storage + Send + Sync + 'static) -> InnerStorage { InnerStorage(Arc::new(Mutex::new(inner))) } } diff --git a/src/core/tests/storage.rs b/src/core/tests/storage.rs index 5a60e02fcc..a27fa27b14 100644 --- a/src/core/tests/storage.rs +++ b/src/core/tests/storage.rs @@ -42,3 +42,41 @@ fn zipstorage_list_sbts() -> Result<(), Box> { Ok(()) } + +#[cfg(feature = "parallel")] +#[test] +fn zipstorage_parallel_access() -> Result<(), Box> { + use std::io::BufReader; + + use rayon::prelude::*; + use sourmash::signature::{Signature, SigsTrait}; + use sourmash::sketch::minhash::KmerMinHash; + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/v6.sbt.zip"); + + let zs = ZipStorage::from_file(filename.to_str().unwrap())?; + + let total_hashes: usize = [ + ".sbt.v3/f71e78178af9e45e6f1d87a0c53c465c", + ".sbt.v3/f0c834bc306651d2b9321fb21d3e8d8f", + ".sbt.v3/4e94e60265e04f0763142e20b52c0da1", + ".sbt.v3/6d6e87e1154e95b279e5e7db414bc37b", + ".sbt.v3/0107d767a345eff67ecdaed2ee5cd7ba", + ".sbt.v3/b59473c94ff2889eca5d7165936e64b3", + ".sbt.v3/60f7e23c24a8d94791cc7a8680c493f9", + ] + .par_iter() + .map(|path| { + let data = zs.load(path).unwrap(); + let sigs: Vec = serde_json::from_reader(&data[..]).expect("Loading error"); + sigs.iter() + .map(|v| v.sketches().iter().map(|mh| mh.size()).sum::()) + .sum::() + }) + .sum(); + + assert_eq!(total_hashes, 3500); + + Ok(()) +} diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 08068255e5..4d1bba2356 100644 --- 
a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -34,10 +34,15 @@ CounterGather - an ancillary class returned by the 'counter_gather()' method. """ +from __future__ import annotations + import os import sourmash from abc import abstractmethod, ABC -from collections import namedtuple, Counter +from collections import Counter +from collections import defaultdict +from typing import NamedTuple, Optional, TypedDict, TYPE_CHECKING +import weakref from sourmash.search import (make_jaccard_search_query, make_containment_query, @@ -45,12 +50,79 @@ from sourmash.manifest import CollectionManifest from sourmash.logging import debug_literal from sourmash.signature import load_signatures, save_signatures +from sourmash._lowlevel import ffi, lib +from sourmash.utils import RustObject, rustcall, decode_str, encode_str +from sourmash import SourmashSignature +from sourmash.picklist import SignaturePicklist from sourmash.minhash import (flatten_and_downsample_scaled, flatten_and_downsample_num, flatten_and_intersect_scaled) -# generic return tuple for Index.search and Index.gather -IndexSearchResult = namedtuple('Result', 'score, signature, location') + +if TYPE_CHECKING: + from typing_extensions import Unpack + + +class IndexSearchResult(NamedTuple): + """generic return tuple for Index.search and Index.gather""" + score: float + signature: SourmashSignature + location: str + + +class Selection(TypedDict): + ksize: Optional[int] + moltype: Optional[str] + num: Optional[int] + scaled: Optional[int] + containment: Optional[bool] + abund: Optional[bool] + picklist: Optional[SignaturePicklist] + + +# TypedDict can't have methods (it is a dict in runtime) +def _selection_as_rust(selection: Selection): + ptr = lib.selection_new() + + for key, v in selection.items(): + if v is not None: + if key == "ksize": + rustcall(lib.selection_set_ksize, ptr, v) + + elif key == "moltype": + hash_function = None + if v.lower() == "dna": + hash_function = lib.HASH_FUNCTIONS_MURMUR64_DNA + elif v.lower() == "protein": + hash_function = lib.HASH_FUNCTIONS_MURMUR64_PROTEIN + elif v.lower() == "dayhoff": + hash_function = lib.HASH_FUNCTIONS_MURMUR64_DAYHOFF + elif v.lower() == "hp": + hash_function = lib.HASH_FUNCTIONS_MURMUR64_HP + + rustcall(lib.selection_set_moltype, ptr, hash_function) + + elif key == "num": + raise NotImplementedError("num") + + elif key == "scaled": + raise NotImplementedError("scaled") + + elif key == "containment": + raise NotImplementedError("containment") + + elif key == "abund": + rustcall(lib.selection_set_abund, ptr, bool(v)) + + elif key == "picklist": + picklist_ptr = v._as_rust() + rustcall(lib.selection_set_picklist, ptr, picklist_ptr) + + else: + raise KeyError(f"Unsupported key {key} for Selection in rust") + + return ptr + class Index(ABC): # this will be removed soon; see sourmash#1894. @@ -307,8 +379,7 @@ def counter_gather(self, query, threshold_bp, **kwargs): return counter @abstractmethod - def select(self, ksize=None, moltype=None, scaled=None, num=None, - abund=None, containment=None): + def select(self, **kwargs: Unpack[Selection]): """Return Index containing only signatures that match requirements. 
Current arguments can be any or all of: @@ -326,9 +397,16 @@ def select(self, ksize=None, moltype=None, scaled=None, num=None, """ -def select_signature(ss, *, ksize=None, moltype=None, scaled=0, num=0, - containment=False, abund=None, picklist=None): +def select_signature(ss, **kwargs: Unpack[Selection]): "Check that the given signature matches the specified requirements." + ksize = kwargs.get('ksize') + moltype = kwargs.get('moltype') + containment = kwargs.get('containment', False) + scaled = kwargs.get('scaled', 0) + num = kwargs.get('num', 0) + abund = kwargs.get('abund') + picklist = kwargs.get('picklist') + # ksize match? if ksize and ksize != ss.minhash.ksize: return False @@ -408,7 +486,7 @@ def load(cls, location, filename=None): lidx = LinearIndex(si, filename=filename) return lidx - def select(self, **kwargs): + def select(self, **kwargs: Unpack[Selection]): """Return new LinearIndex containing only signatures that match req's. Does not raise ValueError, but may return an empty Index. @@ -479,7 +557,7 @@ def save(self, path): def load(cls, path): raise NotImplementedError - def select(self, **kwargs): + def select(self, **kwargs: Unpack[Selection]): """Return new object yielding only signatures that match req's. Does not raise ValueError, but may return an empty Index. @@ -642,7 +720,7 @@ def signatures(self): if select(ss): yield ss - def select(self, **kwargs): + def select(self, **kwargs: Unpack[Selection]): "Select signatures in zip file based on ksize/moltype/etc." # if we have a manifest, run 'select' on the manifest. @@ -1053,7 +1131,7 @@ def load_from_pathlist(cls, filename): def save(self, *args): raise NotImplementedError - def select(self, **kwargs): + def select(self, **kwargs: Unpack[Selection]): "Run 'select' on the manifest." new_manifest = self.manifest.select_to_manifest(**kwargs) return MultiIndex(new_manifest, self.parent, @@ -1162,8 +1240,135 @@ def save(self, *args): def insert(self, *args): raise NotImplementedError - def select(self, **kwargs): + def select(self, **kwargs: Unpack[Selection]): "Run 'select' on the manifest." new_manifest = self.manifest.select_to_manifest(**kwargs) return StandaloneManifestIndex(new_manifest, self._location, prefix=self.prefix) + +class RustLinearIndex(Index, RustObject): + """\ + A read-only collection of signatures in a zip file. + + Does not support `insert` or `save`. + + Concrete class; signatures dynamically loaded from disk; uses manifests. + """ + is_database = True + + __dealloc_func__ = lib.linearindex_free + + def __init__(self, storage, *, selection_dict=None, + traverse_yield_all=False, manifest=None, use_manifest=True): + + self._selection_dict = selection_dict + self._traverse_yield_all = traverse_yield_all + self._use_manifest = use_manifest + + # Taking ownership of the storage + storage_ptr = storage._take_objptr() + + manifest_ptr = ffi.NULL + # do we have a manifest already? if not, try loading. 
+ if use_manifest: + if manifest is not None: + debug_literal('RustLinearIndex using passed-in manifest') + manifest_ptr = manifest._as_rust()._take_objptr() + + selection_ptr = ffi.NULL + + self._objptr = rustcall(lib.linearindex_new, storage_ptr, + manifest_ptr, selection_ptr, use_manifest) + + """ + if self.manifest is not None: + assert not self.selection_dict, self.selection_dict + if self.selection_dict: + assert self.manifest is None + """ + + @property + def manifest(self): + return CollectionManifest._from_rust(self._methodcall(lib.linearindex_manifest)) + + @manifest.setter + def manifest(self, value): + if value is None: + return # FIXME: can't unset manifest in a Rust Linear Index + self._methodcall(lib.linearindex_set_manifest, value._as_rust()._take_objptr()) + + def __bool__(self): + "Are there any matching signatures in this zipfile? Avoid calling len." + return self._methodcall(lib.linearindex_len) > 0 + + def __len__(self): + "calculate number of signatures." + return self._methodcall(lib.linearindex_len) + + @property + def location(self): + return decode_str(self._methodcall(lib.linearindex_location)) + + @property + def storage(self): + from ..sbt_storage import ZipStorage + + ptr = self._methodcall(lib.linearindex_storage) + return ZipStorage._from_objptr(ptr) + + def insert(self, signature): + raise NotImplementedError + + def save(self, path): + raise NotImplementedError + + @classmethod + def load(cls, location, traverse_yield_all=False, use_manifest=True): + "Class method to load a zipfile." + from ..sbt_storage import ZipStorage + + # we can only load from existing zipfiles in this method. + if not os.path.exists(location): + raise FileNotFoundError(location) + + storage = ZipStorage(location) + return cls(storage, traverse_yield_all=traverse_yield_all, + use_manifest=use_manifest) + + def _signatures_with_internal(self): + """Return an iterator of tuples (ss, internal_location). + + Note: does not limit signatures to subsets. + """ + # list all the files, without using the Storage interface; currently, + # 'Storage' does not provide a way to list all the files, so :shrug:. + for filename in self.storage._filenames(): + # should we load this file? if it ends in .sig OR we are forcing: + if filename.endswith('.sig') or \ + filename.endswith('.sig.gz') or \ + self._traverse_yield_all: + sig_data = self.storage.load(filename) + for ss in load_signatures(sig_data): + yield ss, filename + + def signatures(self): + "Load all signatures in the zip file." + attached_refs = weakref.WeakKeyDictionary() + iterator = self._methodcall(lib.linearindex_signatures) + + next_sig = rustcall(lib.signatures_iter_next, iterator) + while next_sig != ffi.NULL: + attached_refs[next_sig] = iterator + yield SourmashSignature._from_objptr(next_sig) + next_sig = rustcall(lib.signatures_iter_next, iterator) + + def select(self, **kwargs: Unpack[Selection]): + "Select signatures in zip file based on ksize/moltype/etc." 
+ + selection = _selection_as_rust(kwargs) + + # select consumes the current index + ptr = self._take_objptr() + ptr = rustcall(lib.linearindex_select, ptr, selection) + + return RustLinearIndex._from_objptr(ptr) diff --git a/src/sourmash/manifest.py b/src/sourmash/manifest.py index bfd27eabb9..d2f78563cb 100644 --- a/src/sourmash/manifest.py +++ b/src/sourmash/manifest.py @@ -7,9 +7,13 @@ import os.path from abc import abstractmethod import itertools +from typing import TYPE_CHECKING from sourmash.picklist import SignaturePicklist +if TYPE_CHECKING: + from typing_extensions import Unpack + class BaseCollectionManifest: """ @@ -303,6 +307,7 @@ def _select(self, *, ksize=None, moltype=None, scaled=0, num=0, for row in matching_rows: yield row + #def select_to_manifest(self, **kwargs: Unpack[Selection]): def select_to_manifest(self, **kwargs): "Do a 'select' and return a new CollectionManifest object." new_rows = self._select(**kwargs) @@ -343,3 +348,34 @@ def to_picklist(self): picklist.pickset = set(self._md5_set) return picklist + + @staticmethod + def _from_rust(value): + from ._lowlevel import ffi, lib + from .utils import rustcall, decode_str + + iterator = rustcall(lib.manifest_rows, value) + + rows = [] + next_row = rustcall(lib.manifest_rows_iter_next, iterator) + while next_row != ffi.NULL: + + # TODO: extract row data from next_row + # FIXME: free mem from strings? + row = {} + row['md5'] = decode_str(next_row.md5) + row['md5short'] = row['md5'][:8] + row['ksize'] = next_row.ksize + row['moltype'] = decode_str(next_row.moltype) + row['num'] = 0 #ss.minhash.num + row['scaled'] = 0 #ss.minhash.scaled + row['n_hashes'] = 0 # len(ss.minhash) + row['with_abundance'] = next_row.with_abundance + row['name'] = decode_str(next_row.name) + row['filename'] = "" #ss.filename + row['internal_location'] = decode_str(next_row.internal_location) + rows.append(row) + + next_row = rustcall(lib.manifest_rows_iter_next, iterator) + + return CollectionManifest(rows) diff --git a/src/sourmash/picklist.py b/src/sourmash/picklist.py index 30d5c84f90..b2764ffce7 100644 --- a/src/sourmash/picklist.py +++ b/src/sourmash/picklist.py @@ -252,6 +252,24 @@ def filter(self, it): if self.__contains__(ss): yield ss + def _as_rust(self): + from ._lowlevel import ffi, lib + from .utils import rustcall, decode_str + + ptr = lib.picklist_new() + + rustcall(lib.picklist_set_coltype, ptr, self.coltype) + rustcall(lib.picklist_set_pickfile, ptr, self.pickfile) + rustcall(lib.picklist_set_column_name, ptr, self.column_name) + rustcall(lib.picklist_set_pickstyle, ptr, self.pickstyle) + + #self.preprocess_fn = preprocess[coltype] + #self.pickset = None + #self.found = set() + #self.n_queries = 0 + + return ptr + def passes_all_picklists(ss, picklists): "does the signature 'ss' pass all of the picklists?" diff --git a/src/sourmash/utils.py b/src/sourmash/utils.py index 71afc20261..acb4b73d7a 100644 --- a/src/sourmash/utils.py +++ b/src/sourmash/utils.py @@ -29,6 +29,13 @@ def _get_objptr(self): raise RuntimeError("Object is closed") return self._objptr + def _take_objptr(self): + if not self._objptr: + raise RuntimeError("Object is closed") + ret = self._objptr + self._objptr = None + return ret + def __del__(self): if self._objptr is None or self._shared: return
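
Taken together, the new FFI plumbing exposes a zipfile-backed collection to Python whose manifest handling, select() and signature iteration all run on the Rust side. A minimal usage sketch, assuming a hypothetical signature zipfile at "collection.zip"; the path and the printed attributes are illustrative only, not part of this changeset:

from sourmash.index import RustLinearIndex

# load an existing zipfile collection; linearindex_new() takes ownership of the
# ZipStorage and, with use_manifest=True, resolves the manifest on the Rust side
idx = RustLinearIndex.load("collection.zip")
print(len(idx), "signatures before selection")    # linearindex_len()

# select() builds a SourmashSelection from the kwargs via _selection_as_rust()
# and consumes the old index pointer, returning a new RustLinearIndex
idx = idx.select(ksize=31, moltype="DNA", abund=False)

# signatures() walks the Rust iterator via signatures_iter_next()
for ss in idx.signatures():
    print(ss.md5sum(), ss.minhash.ksize)

Because select() calls _take_objptr(), the pre-selection index must not be reused afterwards; rebinding the name, as above, is the intended pattern.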
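
On the Rust side, Manifest::from_reader() deserializes CSV rows by header name, skips lines starting with '#', and parses with_abundance through to_bool() (0/1/true/false). A hedged sketch of writing a CSV that this reader should accept; the row values and the comment line are placeholders, and the extra Python-manifest columns (md5short, num, scaled, n_hashes, filename) are assumed to be ignored since they are commented out of Record:

import csv

fields = ["internal_location", "ksize", "with_abundance", "md5", "name", "moltype"]
rows = [{
    "internal_location": "signatures/0a1b2c3d.sig.gz",  # hypothetical member of the zipfile
    "ksize": 31,
    "with_abundance": 0,      # to_bool() accepts 0/1/true/false
    "md5": "0" * 32,          # placeholder digest
    "name": "example genome",
    "moltype": "DNA",
}]

with open("manifest.csv", "w", newline="") as fp:
    fp.write("# comment lines like this one are skipped by the Rust reader\n")
    writer = csv.DictWriter(fp, fieldnames=fields)
    writer.writeheader()
    writer.writerows(rows)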