[MRG] Refactor the database loading code in sourmash_args #1373

Merged 6 commits on Mar 9, 2021
9 changes: 6 additions & 3 deletions src/sourmash/sbt.py
@@ -183,6 +183,7 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None):
         if cache_size is None:
             cache_size = sys.maxsize
         self._nodescache = _NodesCache(maxsize=cache_size)
+        self.location = None

     def signatures(self):
         for k in self.leaves():
@@ -389,7 +390,7 @@ def search(self, query, *args, **kwargs):
             # tree search should always/only return matches above threshold
             assert similarity >= threshold

-            results.append((similarity, leaf.data, None))
+            results.append((similarity, leaf.data, self.location))

         return results

@@ -435,7 +436,7 @@ def gather(self, query, *args, **kwargs):
             containment = query.minhash.contained_by(leaf_mh, True)

             assert containment >= threshold, "containment {} not below threshold {}".format(containment, threshold)
-            results.append((containment, leaf.data, None))
+            results.append((containment, leaf.data, self.location))

         results.sort(key=lambda x: -x[0])

@@ -758,7 +759,9 @@ def load(cls, location, *, leaf_loader=None, storage=None, print_version_warning
         elif storage is None:
             storage = klass(**jnodes['storage']['args'])

-        return loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size)
+        obj = loader(jnodes, leaf_loader, dirname, storage, print_version_warning=print_version_warning, cache_size=cache_size)
+        obj.location = location
+        return obj

     @staticmethod
     def _load_v1(jnodes, leaf_loader, dirname, storage, *, print_version_warning=True, cache_size=None):
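Taken together, the sbt.py changes mean an SBT loaded from disk now remembers where it came from and reports that location with each match instead of None. A minimal sketch of the resulting behavior, using hypothetical file names and a hypothetical threshold:

# Sketch only: 'demo.sbt.json' and 'query.sig' are hypothetical files.
import sourmash
from sourmash.sbt import SBT

tree = SBT.load('demo.sbt.json')
assert tree.location == 'demo.sbt.json'   # set by SBT.load() in this PR

query = sourmash.load_one_signature('query.sig')
results = tree.search(query, threshold=0.1, do_containment=False,
                      ignore_abundance=False, best_only=False)
for similarity, match, location in results:
    # the third tuple element is now tree.location rather than None
    print(similarity, str(match), location)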
15 changes: 7 additions & 8 deletions src/sourmash/search.py
@@ -29,12 +29,13 @@ def search_databases(query, databases, threshold, do_containment, best_only,
                      ignore_abundance, unload_data=False):
     results = []
     found_md5 = set()
-    for (obj, filename, filetype) in databases:
-        search_iter = obj.search(query, threshold=threshold,
+    for db in databases:
+        search_iter = db.search(query, threshold=threshold,
                                  do_containment=do_containment,
                                  ignore_abundance=ignore_abundance,
                                  best_only=best_only,
                                  unload_data=unload_data)
+
         for (similarity, match, filename) in search_iter:
             md5 = match.md5sum()
             if md5 not in found_md5:
@@ -84,8 +85,8 @@ def _find_best(dblist, query, threshold_bp):
     threshold_bp = int(threshold_bp / query_scaled) * query_scaled

     # search across all databases
-    for (obj, filename, filetype) in dblist:
-        for cont, match, fname in obj.gather(query, threshold_bp=threshold_bp):
+    for db in dblist:
+        for cont, match, fname in db.gather(query, threshold_bp=threshold_bp):
             assert cont # all matches should be nonzero.

             # note, break ties based on name, to ensure consistent order.
@@ -94,9 +95,7 @@ def _find_best(dblist, query, threshold_bp):
                 # update best match.
                 best_cont = cont
                 best_match = match
-
-                # some objects may not have associated filename (e.g. SBTs)
-                best_filename = fname or filename
+                best_filename = fname

     if not best_match:
         return None, None, None
@@ -215,7 +214,7 @@ def gather_databases(query, databases, threshold_bp, ignore_abundance):
                               average_abund=average_abund,
                               median_abund=median_abund,
                               std_abund=std_abund,
-                              filename=filename,
+                              filename=filename, # @CTB
                               md5=best_match.md5sum(),
                               name=str(best_match),
                               match=best_match,
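The search.py changes drop the (obj, filename, filetype) tuples: callers now iterate over database objects directly, and the filename attached to a match comes from the database's own result tuples. A rough sketch of the new consumption pattern, assuming a pre-built databases list and query signature:

# Sketch only: 'databases' is a list of index objects (e.g. SBT, LCA_Database,
# LinearIndex) and 'query' is a SourmashSignature; both are assumed to exist.
for db in databases:
    search_iter = db.search(query, threshold=0.1,
                            do_containment=False,
                            ignore_abundance=False,
                            best_only=False)
    for similarity, match, filename in search_iter:
        # each result tuple already names the database it came from
        print(f'{similarity:.3f}  {str(match)}  (from {filename})')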
10 changes: 6 additions & 4 deletions src/sourmash/sourmash_args.py
@@ -256,6 +256,8 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
     Load one or more SBTs, LCAs, and/or signatures.

     Check for compatibility with query.
+
+    This is basically a user-focused wrapping of _load_databases.
     """
     query_ksize = query.minhash.ksize
     query_moltype = get_moltype(query)
@@ -281,7 +283,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
             siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize)
             siglist = filter_compatible_signatures(query, siglist, 1)
             linear = LinearIndex(siglist, filename=filename)
-            databases.append((linear, filename, False))
+            databases.append(linear)

             n_signatures += len(linear)

@@ -291,7 +293,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
                                             is_similarity_query):
                 sys.exit(-1)

-            databases.append((db, filename, 'SBT'))
+            databases.append(db)
             notify('loaded SBT {}', filename, end='\r')
             n_databases += 1

@@ -304,7 +306,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
             notify('loaded LCA {}', filename, end='\r')
             n_databases += 1

-            databases.append((db, filename, 'LCA'))
+            databases.append(db)

         # signature file
         elif dbtype == DatabaseType.SIGLIST:
@@ -316,7 +318,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
                 sys.exit(-1)

             linear = LinearIndex(siglist, filename=filename)
-            databases.append((linear, filename, 'signature'))
+            databases.append(linear)

             notify('loaded {} signatures from {}', len(linear),
                    filename, end='\r')
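On the sourmash_args.py side, load_dbs_and_sigs() now appends bare index objects, so its return value is a flat list rather than a list of (obj, filename, filetype) tuples; this is what allows the simplified loops in search.py above. A minimal sketch of the new return shape, with hypothetical inputs:

# Sketch only: the database paths and 'query' signature are hypothetical.
from sourmash import sourmash_args

filenames = ['demo.sbt.json', 'extra-sigs.sig']
databases = sourmash_args.load_dbs_and_sigs(filenames, query,
                                            is_similarity_query=True)

for db in databases:
    # each entry is an index object (SBT, LCA_Database, LinearIndex), not a tuple;
    # SBTs now carry their load location, other types may not have one.
    print(type(db).__name__, getattr(db, 'location', None))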