Refactor and clean up gather code (#517)
* do appropriate renaming
* rename R_ stuff to scaled
* add comments to sbt find function
* share found signatures across repeated searches
* get rid of unnecessary sorting
ctb authored and luizirber committed Dec 21, 2018
1 parent 726913b commit e442185
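
The heart of this change is in the last two bullets: `gather` runs the same best-match search over its databases round after round, so a strong match found early can seed the score cutoff for later rounds and let the SBT search truncate whole subtrees. Below is a self-contained toy sketch of that idea; the set-based `containment` and the list-of-sets "database" are illustrative stand-ins, not the sourmash API.

# Toy sketch: share previously-found matches ("remainder") across
# repeated best-match searches so later rounds can prune early.
# A "signature" here is just a frozenset of hashes.

def containment(query, match):
    """Fraction of the query's hashes present in the match."""
    if not query:
        return 0.0
    return len(query & match) / len(query)

def find_best(databases, query, remainder):
    # seed the cutoff with the best score among matches from earlier rounds
    best_so_far = max((containment(query, m) for m in remainder), default=0.0)

    results = []
    for db in databases:                 # each db: a list of signatures
        for match in db:
            score = containment(query, match)
            if score > 0.0 and score >= best_so_far:
                best_so_far = score      # a real SBT search would prune here
                results.append((score, match))

    if not results:
        return 0.0, None
    results.sort(key=lambda x: -x[0])
    for _, match in results[1:]:
        remainder.add(match)             # share runners-up with later rounds
    return results[0]

db = [frozenset({1, 2, 3, 4}), frozenset({1, 2}), frozenset({5, 6})]
remainder = set()
score, best = find_best([db], frozenset({1, 2, 3}), remainder)
print(score, sorted(best))               # 1.0 [1, 2, 3, 4]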
Showing 4 changed files with 57 additions and 32 deletions.
4 changes: 2 additions & 2 deletions sourmash/_minhash.pyx
@@ -198,7 +198,7 @@ cdef class MinHash(object):
        if with_abundance and self.track_abundance:
            return dict(zip(mh.mins, mh.abunds))
        else:
-           return [it for it in sorted(deref(self._this).mins)]
+           return deref(self._this).mins

    def get_hashes(self):
        return self.get_mins()
@@ -368,7 +368,7 @@ cdef class MinHash(object):
            return 0.0
        return self.count_common(other) / len(self.get_mins())

-   def similarity_ignore_maxhash(self, MinHash other):
+   def containment_ignore_maxhash(self, MinHash other):
        a = set(self.get_mins())
        if not a:
            return 0.0
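
The `get_mins()` change above drops a copy-and-sort on every call and returns the underlying container directly; per the commit message, the sort was unnecessary work. A rough sense of the cost being avoided, as plain Python timing rather than sourmash code:

import timeit

mins = list(range(100_000))    # stand-in for a large hash list

with_sort = timeit.timeit(lambda: [i for i in sorted(mins)], number=100)
plain     = timeit.timeit(lambda: mins, number=100)

print(f"copy+sort: {with_sort / 100:.6f}s/call, "
      f"plain return: {plain / 100:.9f}s/call")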
14 changes: 14 additions & 0 deletions sourmash/sbt.py
@@ -185,23 +185,37 @@ def add_node(self, node):
            p = self.parent(p.pos)

    def find(self, search_fn, *args, **kwargs):
+       "Search the tree using `search_fn`."
+
+       # initialize search queue with top node of tree
        matches = []
        visited, queue = set(), [0]
+
+       # while the queue is not empty, load each node and apply search
+       # function.
        while queue:
            node_p = queue.pop(0)
            node_g = self.nodes.get(node_p, None)
+
+           # repair while searching.
            if node_g is None:
                if node_p in self.missing_nodes:
                    self._rebuild_node(node_p)
                    node_g = self.nodes[node_p]
                else:
                    continue

+           # if we have not visited this node before,
            if node_p not in visited:
                visited.add(node_p)
+
+               # apply search fn. If return false, truncate search.
                if search_fn(node_g, *args):
+
+                   # leaf node? it's a match!
                    if isinstance(node_g, Leaf):
                        matches.append(node_g)
+                   # internal node? descend.
                    elif isinstance(node_g, Node):
                        if kwargs.get('dfs', True): # defaults search to dfs
                            for c in self.children(node_p):
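
The newly commented `find()` is a breadth-first traversal in which `search_fn` doubles as a pruning predicate: a false return at an internal node skips its entire subtree, and a true return at a leaf records a match. A standalone sketch of the same pattern with toy node classes (not sourmash's `Node`/`Leaf`):

from collections import deque

class Node:
    """Internal node: carries a summary of everything below it."""
    def __init__(self, summary, children):
        self.summary, self.children = summary, children

class Leaf:
    """Leaf node: carries one actual signature."""
    def __init__(self, summary, name):
        self.summary, self.name = summary, name

def find(root, search_fn):
    matches = []
    queue = deque([root])                # initialize queue with the top node
    while queue:
        node = queue.popleft()
        if not search_fn(node):          # search_fn false? truncate search
            continue
        if isinstance(node, Leaf):
            matches.append(node)         # leaf node? it's a match!
        else:
            queue.extend(node.children)  # internal node? descend
    return matches

query = {2, 5}
tree = Node({1, 2, 3, 5}, [
    Leaf({1, 2}, "sig-A"),
    Node({3, 5}, [Leaf({3}, "sig-B"), Leaf({5}, "sig-C")]),
])
print([l.name for l in find(tree, lambda n: bool(n.summary & query))])
# prints ['sig-A', 'sig-C']; "sig-B" fails the check and is dropped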
23 changes: 12 additions & 11 deletions sourmash/sbtmh.py
@@ -203,38 +203,39 @@ def search_minhashes_containment(node, sig, threshold,
    return 0


-class SearchMinHashesFindBestIgnoreMaxHash(object):
-   def __init__(self):
-       self.best_match = 0.
+class GatherMinHashesFindBestIgnoreMaxHash(object):
+   def __init__(self, initial_best_match=0.0):
+       self.best_match = initial_best_match

    def search(self, node, sig, threshold, results=None):
        mins = sig.minhash.get_mins()

+       score = 0
+       if not len(mins):
+           return 0
+
        if isinstance(node, SigLeaf):
            max_scaled = max(node.data.minhash.scaled, sig.minhash.scaled)

            mh1 = node.data.minhash.downsample_scaled(max_scaled)
            mh2 = sig.minhash.downsample_scaled(max_scaled)
            matches = mh1.count_common(mh2)
-       else: # Node or Leaf, Nodegraph by minhash comparison
+       else: # Nodegraph by minhash comparison
            get = node.data.get
            matches = sum(1 for value in mins if get(value))

-       score = 0
-       if not len(mins):
-           return 0
-
        score = float(matches) / len(mins)

        # store results if we have passed in an appropriate dictionary
        if results is not None:
            results[node.name] = score

        if score >= threshold:
-           # have we done better than this? if yes, truncate.
-           if float(matches) / len(mins) > self.best_match:
+           # have we done better than this? if no, truncate searches below.
+           if score >= self.best_match:
                # update best if it's a leaf node...
                if isinstance(node, SigLeaf):
-                   self.best_match = float(matches) / len(mins)
+                   self.best_match = score
                return 1

        return 0
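
In the `SigLeaf` branch above, the two sketches are first downsampled to the larger (coarser) `scaled` value so shared hashes are counted at a common resolution. A standalone sketch of that arithmetic with plain sets, under the assumption that a scaled sketch keeps exactly the hashes below `2**64 / scaled`:

import random

H = 2**64

def downsample(hashes, scaled):
    # a scaled sketch keeps only the hashes below H / scaled
    cutoff = H // scaled
    return {h for h in hashes if h < cutoff}

def containment_at_common_scaled(query, match, query_scaled, match_scaled):
    # compare at the coarser resolution: the larger of the two scaled values
    common = max(query_scaled, match_scaled)
    q = downsample(query, common)
    m = downsample(match, common)
    if not q:
        return 0.0
    return len(q & m) / len(q)

random.seed(42)
hashes = [random.randrange(H) for _ in range(200_000)]
query = downsample(set(hashes[:120_000]), 1000)   # sketch at scaled=1000
match = downsample(set(hashes[60_000:]), 2000)    # sketch at scaled=2000

# the true overlap is hashes[60_000:120_000], about half the query,
# so the estimate should land near 0.5
print(round(containment_at_common_scaled(query, match, 1000, 2000), 2))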
48 changes: 29 additions & 19 deletions sourmash/search.py
@@ -5,7 +5,7 @@
from .logging import notify, error
from .signature import SourmashSignature
from .sbtmh import search_minhashes, search_minhashes_containment
-from .sbtmh import SearchMinHashesFindBest, SearchMinHashesFindBestIgnoreMaxHash
+from .sbtmh import SearchMinHashesFindBest, GatherMinHashesFindBestIgnoreMaxHash
from ._minhash import get_max_hash_for_scaled


@@ -107,24 +107,29 @@ def gather_databases(query, databases, threshold_bp, ignore_abundance):
    import numpy as np
    orig_abunds = orig_query.minhash.get_mins(with_abundance=True)

-   # calculate the band size/resolution R for the genome
-   R_metagenome = orig_query.minhash.scaled
+   # store the scaled value for the query
+   orig_scaled = orig_query.minhash.scaled

    # define a function to do a 'best' search and get only top match.
-   def find_best(dblist, query):
-       # CTB: could optimize by sharing scores across searches, i.e.
-       # a good early score truncates later searches.
+   def find_best(dblist, query, remainder):
+
+       # precompute best containment from all of the remainders
+       best_ctn_sofar = 0.0
+       for x in remainder:
+           ctn = query.minhash.containment_ignore_maxhash(x.minhash)
+           if ctn > best_ctn_sofar:
+               best_ctn_sofar = ctn

        results = []
        for (obj, filename, filetype) in dblist:
            # search a tree
            if filetype == 'SBT':
                tree = obj
-               search_fn = SearchMinHashesFindBestIgnoreMaxHash().search
+               search_fn = GatherMinHashesFindBestIgnoreMaxHash(best_ctn_sofar).search

-               for leaf in tree.find(search_fn, query, 0.0):
+               for leaf in tree.find(search_fn, query, best_ctn_sofar):
                    leaf_e = leaf.data.minhash
-                   similarity = query.minhash.similarity_ignore_maxhash(leaf_e)
+                   similarity = query.minhash.containment_ignore_maxhash(leaf_e)
                    if similarity > 0.0:
                        results.append((similarity, leaf.data))
            # or an LCA database
@@ -139,7 +144,7 @@ def find_best(dblist, query):
            # search a signature
            else:
                for ss in obj:
-                   similarity = query.minhash.similarity_ignore_maxhash(ss.minhash)
+                   similarity = query.minhash.containment_ignore_maxhash(ss.minhash)
                    if similarity > 0.0:
                        results.append((similarity, ss))

@@ -149,6 +154,10 @@ def find_best(dblist, query):
        # take the best result
        results.sort(key=lambda x: (-x[0], x[1].name())) # reverse sort on similarity, and then on name
        best_similarity, best_leaf = results[0]
+
+       for x in results[1:]:
+           remainder.add(x[1])
+
        return best_similarity, best_leaf, filename


@@ -164,9 +173,10 @@ def build_new_signature(mins, template_sig, scaled=None):
    new_mins = query.minhash.get_hashes()
    query = build_new_signature(new_mins, orig_query)

-   R_comparison = 0
+   cmp_scaled = 0
+   remainder = set()
    while 1:
-       best_similarity, best_leaf, filename = find_best(databases, query)
+       best_similarity, best_leaf, filename = find_best(databases, query, remainder)
        if not best_leaf: # no matches at all!
            break

@@ -180,15 +190,15 @@ def build_new_signature(mins, template_sig, scaled=None):
            error('Please prepare database of sequences with --scaled')
            sys.exit(-1)

-       R_genome = best_leaf.minhash.scaled
+       match_scaled = best_leaf.minhash.scaled

-       # pick the highest R / lowest resolution
-       R_comparison = max(R_comparison, R_metagenome, R_genome)
+       # pick the highest scaled / lowest resolution
+       cmp_scaled = max(cmp_scaled, match_scaled, orig_scaled)

        # eliminate mins under this new resolution.
        # (CTB note: this means that if a high scaled/low res signature is
        # found early on, resolution will be low from then on.)
-       new_max_hash = get_max_hash_for_scaled(R_comparison)
+       new_max_hash = get_max_hash_for_scaled(cmp_scaled)
        query_mins = set([ i for i in query_mins if i < new_max_hash ])
        found_mins = set([ i for i in found_mins if i < new_max_hash ])
        orig_mins = set([ i for i in orig_mins if i < new_max_hash ])
@@ -197,7 +207,7 @@ def build_new_signature(mins, template_sig, scaled=None):
        # calculate intersection:
        intersect_mins = query_mins.intersection(found_mins)
        intersect_orig_mins = orig_mins.intersection(found_mins)
-       intersect_bp = R_comparison * len(intersect_orig_mins)
+       intersect_bp = cmp_scaled * len(intersect_orig_mins)

        if intersect_bp < threshold_bp: # hard cutoff for now
            notify('found less than {} in common. => exiting',
@@ -210,7 +220,7 @@ def build_new_signature(mins, template_sig, scaled=None):
        f_orig_query = len(intersect_orig_mins) / float(len(orig_mins))

        # calculate fractions wrt second denominator - metagenome size
-       orig_mh = orig_query.minhash.downsample_scaled(R_comparison)
+       orig_mh = orig_query.minhash.downsample_scaled(cmp_scaled)
        query_n_mins = len(orig_mh)
        f_unique_to_query = len(intersect_mins) / float(query_n_mins)

@@ -241,7 +251,7 @@

        # construct a new query, minus the previous one.
        query_mins -= set(found_mins)
-       query = build_new_signature(query_mins, orig_query, R_comparison)
+       query = build_new_signature(query_mins, orig_query, cmp_scaled)

        weighted_missed = sum((orig_abunds[k] for k in query_mins)) \
                          / sum_abunds
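
Taken together, the loop implements greedy gather: find the best remaining match, estimate the overlap in base pairs (each hash in a scaled sketch stands for roughly `cmp_scaled` bp), stop once that estimate drops below `threshold_bp`, and otherwise subtract the found hashes and repeat. A minimal sketch of one iteration's bookkeeping, with small integers standing in for hashes:

def gather_step(query_mins, match_mins, cmp_scaled, threshold_bp):
    # each retained hash represents ~cmp_scaled bp of sequence
    intersect = query_mins & match_mins
    intersect_bp = cmp_scaled * len(intersect)

    if intersect_bp < threshold_bp:      # hard cutoff, as in the loop above
        return None, query_mins

    # construct the next round's query, minus what was just found
    return intersect_bp, query_mins - match_mins

query = set(range(100))
match = set(range(40))
bp, remaining = gather_step(query, match, cmp_scaled=1000, threshold_bp=10_000)
print(bp, len(remaining))                # 40000 60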
