From d7e306412200ce64a3a00801c0d692ec539965a5 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 6 Mar 2021 11:24:38 -0800 Subject: [PATCH] add 'location' to SBT objects --- src/sourmash/sbt.py | 9 ++++++--- src/sourmash/search.py | 6 ++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/sourmash/sbt.py b/src/sourmash/sbt.py index 9cae16b693..15c6f06c7a 100644 --- a/src/sourmash/sbt.py +++ b/src/sourmash/sbt.py @@ -183,6 +183,7 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None): if cache_size is None: cache_size = sys.maxsize self._nodescache = _NodesCache(maxsize=cache_size) + self.location = None def signatures(self): for k in self.leaves(): @@ -389,7 +390,7 @@ def search(self, query, *args, **kwargs): # tree search should always/only return matches above threshold assert similarity >= threshold - results.append((similarity, leaf.data, None)) + results.append((similarity, leaf.data, self.location)) return results @@ -435,7 +436,7 @@ def gather(self, query, *args, **kwargs): containment = query.minhash.contained_by(leaf_mh, True) assert containment >= threshold, "containment {} not below threshold {}".format(containment, threshold) - results.append((containment, leaf.data, None)) + results.append((containment, leaf.data, self.location)) results.sort(key=lambda x: -x[0]) @@ -758,7 +759,9 @@ def load(cls, location, *, leaf_loader=None, storage=None, print_version_warning elif storage is None: storage = klass(**jnodes['storage']['args']) - return loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size) + obj = loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size) + obj.location = location + return obj @staticmethod def _load_v1(jnodes, leaf_loader, dirname, storage, *, print_version_warning=True, cache_size=None): diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 08e83bc559..045d567f22 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -85,7 +85,7 @@ def _find_best(dblist, query, threshold_bp): threshold_bp = int(threshold_bp / query_scaled) * query_scaled # search across all databases - for (obj, filename) in dblist: + for (obj, _) in dblist: for cont, match, fname in obj.gather(query, threshold_bp=threshold_bp): assert cont # all matches should be nonzero. @@ -95,9 +95,7 @@ def _find_best(dblist, query, threshold_bp): # update best match. best_cont = cont best_match = match - - # some objects may not have associated filename (e.g. SBTs) - best_filename = fname or filename # @CTB + best_filename = fname if not best_match: return None, None, None