[MRG] Refactor the database loading code in sourmash_args (#1373)
* refactor return signature of load_dbs_and_sigs

* more refactor - filename stuff

* add 'location' to SBT objects

* finish removing filename

* Update src/sourmash/sbt.py

Co-authored-by: Luiz Irber <luizirber@users.noreply.github.com>

ctb and luizirber committed Mar 9, 2021
1 parent: 5e66db9, commit: b6c28d8
Showing 3 changed files with 18 additions and 14 deletions.
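For callers, the practical upshot of the refactor: load_dbs_and_sigs now returns a flat list of index objects (LinearIndex, SBT, or LCA databases) instead of (database, filename, filetype) tuples, and each search or gather result carries its own source location. A minimal before/after sketch, not taken from the repository, assuming query is a prepared SourmashSignature, filenames is a list of database paths, and the 0.1 threshold is purely illustrative:

from sourmash.sourmash_args import load_dbs_and_sigs

# before this commit: each entry was a (database, filename, filetype) tuple
for (obj, filename, filetype) in load_dbs_and_sigs(filenames, query, True):
    for similarity, match, fname in obj.search(query, threshold=0.1):
        print(similarity, match.md5sum(), fname or filename)

# after this commit: entries are plain index objects, and the third element
# of each result tuple is the match's source location
for db in load_dbs_and_sigs(filenames, query, True):
    for similarity, match, location in db.search(query, threshold=0.1):
        print(similarity, match.md5sum(), location)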
src/sourmash/sbt.py: 6 additions & 3 deletions
@@ -183,6 +183,7 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None):
         if cache_size is None:
             cache_size = sys.maxsize
         self._nodescache = _NodesCache(maxsize=cache_size)
+        self._location = None

     def signatures(self):
         for k in self.leaves():
@@ -389,7 +390,7 @@ def search(self, query, *args, **kwargs):
             # tree search should always/only return matches above threshold
             assert similarity >= threshold

-            results.append((similarity, leaf.data, None))
+            results.append((similarity, leaf.data, self.location))

         return results

@@ -435,7 +436,7 @@ def gather(self, query, *args, **kwargs):
             containment = query.minhash.contained_by(leaf_mh, True)

             assert containment >= threshold, "containment {} not below threshold {}".format(containment, threshold)
-            results.append((containment, leaf.data, None))
+            results.append((containment, leaf.data, self.location))

         results.sort(key=lambda x: -x[0])

@@ -758,7 +759,9 @@ def load(cls, location, *, leaf_loader=None, storage=None, print_version_warning
             elif storage is None:
                 storage = klass(**jnodes['storage']['args'])

-        return loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size)
+        obj = loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size)
+        obj.location = location
+        return obj

     @staticmethod
     def _load_v1(jnodes, leaf_loader, dirname, storage, *, print_version_warning=True, cache_size=None):
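Net effect of the sbt.py changes: SBT.load records where the tree was loaded from, and search (and gather) results report that location as the third tuple element instead of None. A rough sketch of the new behavior, using the existing load_sbt_index helper that wraps SBT.load for signature trees; the database path is a placeholder, query is an assumed prebuilt SourmashSignature, and the threshold is illustrative:

from sourmash.sbtmh import load_sbt_index   # wraps SBT.load for signature SBTs

tree = load_sbt_index('databases/example.sbt.zip')    # placeholder path
assert tree.location == 'databases/example.sbt.zip'   # set by SBT.load in this commit

for similarity, match, location in tree.search(query, threshold=0.1):
    # the third element is now tree.location rather than None
    print(similarity, match.md5sum(), location)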
src/sourmash/search.py: 6 additions & 7 deletions
@@ -29,12 +29,13 @@ def search_databases(query, databases, threshold, do_containment, best_only,
                      ignore_abundance, unload_data=False):
     results = []
     found_md5 = set()
-    for (obj, filename, filetype) in databases:
-        search_iter = obj.search(query, threshold=threshold,
+    for db in databases:
+        search_iter = db.search(query, threshold=threshold,
                                  do_containment=do_containment,
                                  ignore_abundance=ignore_abundance,
                                  best_only=best_only,
                                  unload_data=unload_data)
+
         for (similarity, match, filename) in search_iter:
             md5 = match.md5sum()
             if md5 not in found_md5:
@@ -84,8 +85,8 @@ def _find_best(dblist, query, threshold_bp):
         threshold_bp = int(threshold_bp / query_scaled) * query_scaled

     # search across all databases
-    for (obj, filename, filetype) in dblist:
-        for cont, match, fname in obj.gather(query, threshold_bp=threshold_bp):
+    for db in dblist:
+        for cont, match, fname in db.gather(query, threshold_bp=threshold_bp):
             assert cont # all matches should be nonzero.

             # note, break ties based on name, to ensure consistent order.
@@ -94,9 +95,7 @@ def _find_best(dblist, query, threshold_bp):
                 # update best match.
                 best_cont = cont
                 best_match = match
-
-                # some objects may not have associated filename (e.g. SBTs)
-                best_filename = fname or filename
+                best_filename = fname

     if not best_match:
         return None, None, None
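In search.py, the callers now lean on the per-result location: search_databases and _find_best iterate over plain index objects and take the match's source path straight from each result tuple. A tiny sketch of consuming gather results under the new shape; db, query, and threshold_bp are assumed to be set up as in _find_best above:

# each gather result is a (containment, match_signature, source_location) tuple
for cont, match, fname in db.gather(query, threshold_bp=threshold_bp):
    print("{:.3f} {} from {}".format(cont, match.md5sum()[:8], fname))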
src/sourmash/sourmash_args.py: 6 additions & 4 deletions
@@ -256,6 +256,8 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
     Load one or more SBTs, LCAs, and/or signatures.

     Check for compatibility with query.
+
+    This is basically a user-focused wrapping of _load_databases.
     """
     query_ksize = query.minhash.ksize
     query_moltype = get_moltype(query)
@@ -281,7 +283,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
             siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize)
             siglist = filter_compatible_signatures(query, siglist, 1)
             linear = LinearIndex(siglist, filename=filename)
-            databases.append((linear, filename, False))
+            databases.append(linear)

             n_signatures += len(linear)

@@ -291,7 +293,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
                                            is_similarity_query):
                 sys.exit(-1)

-            databases.append((db, filename, 'SBT'))
+            databases.append(db)

             notify('loaded SBT {}', filename, end='\r')
             n_databases += 1

@@ -304,7 +306,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
             notify('loaded LCA {}', filename, end='\r')
             n_databases += 1

-            databases.append((db, filename, 'LCA'))
+            databases.append(db)

         # signature file
         elif dbtype == DatabaseType.SIGLIST:
@@ -316,7 +318,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
                 sys.exit(-1)

             linear = LinearIndex(siglist, filename=filename)
-            databases.append((linear, filename, 'signature'))
+            databases.append(linear)

             notify('loaded {} signatures from {}', len(linear),
                    filename, end='\r')
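The reason the (index, filename, filetype) tuples can go away: the filename handed to LinearIndex (and recorded on SBT and LCA databases at load time) comes back as each result's location, so the source path still reaches callers through the results themselves. A small sketch, assuming siglist, filename, and query are prepared as in the code above, and assuming search results report the stored filename as their location, as the tuple unpacking in search.py suggests:

from sourmash.index import LinearIndex

linear = LinearIndex(siglist, filename=filename)
for similarity, match, location in linear.search(query, threshold=0.1):
    assert location == filename   # the source path travels with each result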
