Skip to content

Commit

Permalink
add 'location' to SBT objects
Browse files Browse the repository at this point in the history
  • Loading branch information
ctb committed Mar 6, 2021
1 parent 23eea6d commit d7e3064
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 6 additions & 3 deletions src/sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None):
if cache_size is None:
cache_size = sys.maxsize
self._nodescache = _NodesCache(maxsize=cache_size)
self.location = None

def signatures(self):
for k in self.leaves():
Expand Down Expand Up @@ -389,7 +390,7 @@ def search(self, query, *args, **kwargs):
# tree search should always/only return matches above threshold
assert similarity >= threshold

results.append((similarity, leaf.data, None))
results.append((similarity, leaf.data, self.location))

return results

Expand Down Expand Up @@ -435,7 +436,7 @@ def gather(self, query, *args, **kwargs):
containment = query.minhash.contained_by(leaf_mh, True)

assert containment >= threshold, "containment {} not below threshold {}".format(containment, threshold)
results.append((containment, leaf.data, None))
results.append((containment, leaf.data, self.location))

results.sort(key=lambda x: -x[0])

Expand Down Expand Up @@ -758,7 +759,9 @@ def load(cls, location, *, leaf_loader=None, storage=None, print_version_warning
elif storage is None:
storage = klass(**jnodes['storage']['args'])

return loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size)
obj = loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size)
obj.location = location
return obj

@staticmethod
def _load_v1(jnodes, leaf_loader, dirname, storage, *, print_version_warning=True, cache_size=None):
Expand Down
6 changes: 2 additions & 4 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _find_best(dblist, query, threshold_bp):
threshold_bp = int(threshold_bp / query_scaled) * query_scaled

# search across all databases
for (obj, filename) in dblist:
for (obj, _) in dblist:
for cont, match, fname in obj.gather(query, threshold_bp=threshold_bp):
assert cont # all matches should be nonzero.

Expand All @@ -95,9 +95,7 @@ def _find_best(dblist, query, threshold_bp):
# update best match.
best_cont = cont
best_match = match

# some objects may not have associated filename (e.g. SBTs)
best_filename = fname or filename # @CTB
best_filename = fname

if not best_match:
return None, None, None
Expand Down

0 comments on commit d7e3064

Please sign in to comment.