diff --git a/src/sourmash/sbt.py b/src/sourmash/sbt.py index 9cae16b693..62eb604c18 100644 --- a/src/sourmash/sbt.py +++ b/src/sourmash/sbt.py @@ -183,6 +183,7 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None): if cache_size is None: cache_size = sys.maxsize self._nodescache = _NodesCache(maxsize=cache_size) + self._location = None def signatures(self): for k in self.leaves(): @@ -389,7 +390,7 @@ def search(self, query, *args, **kwargs): # tree search should always/only return matches above threshold assert similarity >= threshold - results.append((similarity, leaf.data, None)) + results.append((similarity, leaf.data, self.location)) return results @@ -435,7 +436,7 @@ def gather(self, query, *args, **kwargs): containment = query.minhash.contained_by(leaf_mh, True) assert containment >= threshold, "containment {} not below threshold {}".format(containment, threshold) - results.append((containment, leaf.data, None)) + results.append((containment, leaf.data, self.location)) results.sort(key=lambda x: -x[0]) @@ -758,7 +759,9 @@ def load(cls, location, *, leaf_loader=None, storage=None, print_version_warning elif storage is None: storage = klass(**jnodes['storage']['args']) - return loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size) + obj = loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size) + obj.location = location + return obj @staticmethod def _load_v1(jnodes, leaf_loader, dirname, storage, *, print_version_warning=True, cache_size=None): diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 557fb4689d..272c3e9352 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -29,12 +29,13 @@ def search_databases(query, databases, threshold, do_containment, best_only, ignore_abundance, unload_data=False): results = [] found_md5 = set() - for (obj, filename, filetype) in databases: - search_iter = obj.search(query, threshold=threshold, + for db in databases: + search_iter = db.search(query, threshold=threshold, do_containment=do_containment, ignore_abundance=ignore_abundance, best_only=best_only, unload_data=unload_data) + for (similarity, match, filename) in search_iter: md5 = match.md5sum() if md5 not in found_md5: @@ -84,8 +85,8 @@ def _find_best(dblist, query, threshold_bp): threshold_bp = int(threshold_bp / query_scaled) * query_scaled # search across all databases - for (obj, filename, filetype) in dblist: - for cont, match, fname in obj.gather(query, threshold_bp=threshold_bp): + for db in dblist: + for cont, match, fname in db.gather(query, threshold_bp=threshold_bp): assert cont # all matches should be nonzero. # note, break ties based on name, to ensure consistent order. @@ -94,9 +95,7 @@ def _find_best(dblist, query, threshold_bp): # update best match. best_cont = cont best_match = match - - # some objects may not have associated filename (e.g. SBTs) - best_filename = fname or filename + best_filename = fname if not best_match: return None, None, None diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py index 258b99cee0..aaf1536c7d 100644 --- a/src/sourmash/sourmash_args.py +++ b/src/sourmash/sourmash_args.py @@ -256,6 +256,8 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) Load one or more SBTs, LCAs, and/or signatures. Check for compatibility with query. + + This is basically a user-focused wrapping of _load_databases. """ query_ksize = query.minhash.ksize query_moltype = get_moltype(query) @@ -281,7 +283,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize) siglist = filter_compatible_signatures(query, siglist, 1) linear = LinearIndex(siglist, filename=filename) - databases.append((linear, filename, False)) + databases.append(linear) n_signatures += len(linear) @@ -291,7 +293,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) is_similarity_query): sys.exit(-1) - databases.append((db, filename, 'SBT')) + databases.append(db) notify('loaded SBT {}', filename, end='\r') n_databases += 1 @@ -304,7 +306,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) notify('loaded LCA {}', filename, end='\r') n_databases += 1 - databases.append((db, filename, 'LCA')) + databases.append(db) # signature file elif dbtype == DatabaseType.SIGLIST: @@ -316,7 +318,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) sys.exit(-1) linear = LinearIndex(siglist, filename=filename) - databases.append((linear, filename, 'signature')) + databases.append(linear) notify('loaded {} signatures from {}', len(linear), filename, end='\r')