Merge pull request #380 from dib-lab/fix/sbt_search
[MRG] Fixes and cleanups for query and subject incompatibility in sourmash search.
ctb committed Feb 7, 2018
2 parents bf91ee2 + 026dca2 commit f469b5a
Showing 13 changed files with 555 additions and 109 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ numpy
matplotlib
scipy
Cython
khmer
khmer>=2.1,<3
sphinx
alabaster
recommonmark
2 changes: 1 addition & 1 deletion setup.py
@@ -65,7 +65,7 @@
language="c++",
extra_compile_args=EXTRA_COMPILE_ARGS,
extra_link_args=EXTRA_LINK_ARGS)],
"install_requires": ["screed>=0.9", "ijson", "khmer>2.0<3.0"],
"install_requires": ["screed>=0.9", "ijson", "khmer>=2.1<3.0"],
"setup_requires": ['Cython>=0.25.2', "setuptools>=18.0"],
"extras_require": {
'test' : ['pytest', 'pytest-cov', 'numpy', 'matplotlib', 'scipy'],
2 changes: 1 addition & 1 deletion sourmash_lib/VERSION
@@ -1 +1 @@
2.0.0a2
2.0.0a3
3 changes: 3 additions & 0 deletions sourmash_lib/_minhash.pyx
@@ -264,6 +264,9 @@ cdef class MinHash(object):
return self.downsample_scaled(new_scaled)

def downsample_scaled(self, new_num):
if self.num:
raise ValueError('num != 0 - cannot downsample a standard MinHash')

max_hash = self.max_hash
if max_hash is None:
raise ValueError('no max_hash available - cannot downsample')
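The new guard in downsample_scaled() means only scaled MinHashes can be downsampled further; a num-based MinHash now fails loudly. A minimal usage sketch (not from the PR), assuming the MinHash constructor keywords (ksize, n, max_hash) used elsewhere in this diff; the ksize and scaled values are made up:

```python
import sourmash_lib

scaled_mh = sourmash_lib.MinHash(n=0, ksize=31, max_hash=2**64 // 1000)
num_mh = sourmash_lib.MinHash(n=500, ksize=31)

coarser = scaled_mh.downsample_scaled(10000)   # OK: scaled MinHash, coarser resolution
try:
    num_mh.downsample_scaled(10000)            # num != 0: now raises ValueError
except ValueError as err:
    print(err)   # "num != 0 - cannot downsample a standard MinHash"
```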
57 changes: 32 additions & 25 deletions sourmash_lib/commands.py
@@ -13,6 +13,7 @@
from . import signature as sig
from . import sourmash_args
from .logging import notify, error, print_results, set_quiet
from .search import format_bp

from .sourmash_args import DEFAULT_LOAD_K
DEFAULT_COMPUTE_K = '21,31,51'
@@ -610,6 +611,9 @@ def sbt_combine(args):


def index(args):
"""
Build a Sequence Bloom Tree index of the given signatures.
"""
import sourmash_lib.sbtmh

parser = argparse.ArgumentParser()
@@ -657,6 +661,8 @@ def index(args):
n = 0
ksizes = set()
moltypes = set()
nums = set()
scaleds = set()
for f in inp_files:
siglist = sig.load_signatures(f, ksize=args.ksize,
select_moltype=moltype)
@@ -665,6 +671,8 @@
for ss in siglist:
ksizes.add(ss.minhash.ksize)
moltypes.add(sourmash_args.get_moltype(ss))
nums.add(ss.minhash.num)
scaleds.add(ss.minhash.scaled)

leaf = sourmash_lib.sbtmh.SigLeaf(ss.md5sum(), ss)
tree.add_node(leaf)
@@ -678,6 +686,16 @@
", ".join(map(str, ksizes)), ", ".join(moltypes))
sys.exit(-1)

if nums == { 0 } and len(scaleds) == 1:
pass # good
elif scaleds == { 0 } and len(nums) == 1:
pass # also good
else:
error('trying to build an SBT with incompatible signatures.')
error('nums = {}; scaleds = {}', repr(nums), repr(scaleds))
sys.exit(-1)


# did we load any!?
if n == 0:
error('no signatures found to load into tree!? failing.')
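The check added to index() encodes a simple rule: an SBT may contain either scaled signatures that all share one scaled value, or num signatures that all share one num value, never a mixture. A standalone sketch of that rule using a hypothetical helper name (not part of this PR):

```python
def signatures_compatible(nums, scaleds):
    """Hypothetical helper mirroring the guard added to index()."""
    if nums == {0} and len(scaleds) == 1:
        return True      # all scaled signatures, one shared scaled value
    if scaleds == {0} and len(nums) == 1:
        return True      # all num signatures, one shared num value
    return False

assert signatures_compatible({0}, {2000})               # scaled-only: accepted
assert signatures_compatible({500}, {0})                # num-only: accepted
assert not signatures_compatible({0, 500}, {0, 2000})   # mixed: rejected
```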
@@ -724,11 +742,9 @@ def search(args):
query = sourmash_args.load_query_signature(args.query,
ksize=args.ksize,
select_moltype=moltype)
query_moltype = sourmash_args.get_moltype(query)
query_ksize = query.minhash.ksize
notify('loaded query: {}... (k={}, {})', query.name()[:30],
query_ksize,
query_moltype)
query.minhash.ksize,
sourmash_args.get_moltype(query))

# downsample if requested
if args.scaled:
@@ -741,8 +757,8 @@
query.minhash = query.minhash.downsample_scaled(args.scaled)

# set up the search databases
databases = sourmash_args.load_sbts_and_sigs(args.databases,
query_ksize, query_moltype,
databases = sourmash_args.load_sbts_and_sigs(args.databases, query,
not args.containment,
args.traverse_directory)

if not len(databases):
@@ -903,11 +919,9 @@ def gather(args):
query = sourmash_args.load_query_signature(args.query,
ksize=args.ksize,
select_moltype=moltype)
query_moltype = sourmash_args.get_moltype(query)
query_ksize = query.minhash.ksize
notify('loaded query: {}... (k={}, {})', query.name()[:30],
query_ksize,
query_moltype)
query.minhash.ksize,
sourmash_args.get_moltype(query))

# verify signature was computed right.
if query.minhash.max_hash == 0:
@@ -926,8 +940,7 @@
sys.exit(-1)

# set up the search databases
databases = sourmash_args.load_sbts_and_sigs(args.databases,
query_ksize, query_moltype,
databases = sourmash_args.load_sbts_and_sigs(args.databases, query, False,
args.traverse_directory)

if not len(databases):
@@ -995,7 +1008,7 @@ def gather(args):
outname = args.output_unassigned.name
notify('saving unassigned hashes to "{}"', outname)

e = sourmash_lib.MinHash(ksize=query_ksize, n=0,
e = sourmash_lib.MinHash(ksize=query.minhash.ksize, n=0,
max_hash=new_max_hash)
e.add_many(query.minhash.get_mins())
sig.save_signatures([ sig.SourmashSignature(e) ],
@@ -1044,24 +1057,18 @@ def watch(args):

tree = sourmash_lib.load_sbt_index(args.sbt_name)

def get_ksize(tree):
"""Walk nodes in `tree` to find out ksize"""
for node in tree.nodes.values():
if isinstance(node, sourmash_lib.sbtmh.SigLeaf):
return node.data.minhash.ksize

# deduce ksize from the SBT we are loading
# check ksize from the SBT we are loading
ksize = args.ksize
if ksize is None:
ksize = get_ksize(tree)
leaf = next(iter(tree.leaves()))
tree_mh = leaf.data.minhash
ksize = tree_mh.ksize

E = sourmash_lib.MinHash(ksize=ksize, n=args.num_hashes,
is_protein=is_protein)
streamsig = sig.SourmashSignature(E, filename='stdin',
name=args.name)
streamsig = sig.SourmashSignature(E, filename='stdin', name=args.name)

notify('Computing signature for k={}, {} from stdin',
ksize, moltype)
notify('Computing signature for k={}, {} from stdin', ksize, moltype)

def do_search():
search_fn = SearchMinHashesFindBest().search
62 changes: 44 additions & 18 deletions sourmash_lib/sbt.py
@@ -118,13 +118,13 @@ def __init__(self, factory, d=2, storage=None):
self.nodes = defaultdict(lambda: None)
self.missing_nodes = set()
self.d = d
self.max_node = 0
self.next_node = 0
self.storage = storage

def new_node_pos(self, node):
while self.nodes[self.max_node] is not None:
self.max_node += 1
return self.max_node
while self.nodes.get(self.next_node, None) is not None:
self.next_node += 1
return self.next_node

def add_node(self, node):
pos = self.new_node_pos(node)
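A side note on the switch from self.nodes[...] to self.nodes.get(...) seen in new_node_pos() above and in the traversal code below: indexing a defaultdict inserts the missing key, while .get() leaves it untouched, so read-only lookups no longer grow the sparse node table. A small plain-Python illustration (not sourmash code):

```python
from collections import defaultdict

nodes = defaultdict(lambda: None)

_ = nodes[3]             # a miss via [] inserts the key: 3 -> None
_ = nodes.get(7, None)   # a miss via .get() leaves the dict untouched

print(sorted(nodes))     # [3]
```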
@@ -178,7 +178,7 @@ def find(self, search_fn, *args, **kwargs):
visited, queue = set(), [0]
while queue:
node_p = queue.pop(0)
node_g = self.nodes[node_p]
node_g = self.nodes.get(node_p, None)
if node_g is None:
if node_p in self.missing_nodes:
self._rebuild_node(node_p)
@@ -210,7 +210,7 @@ def _rebuild_node(self, pos=0):
(the default).
"""

node = self.nodes[pos]
node = self.nodes.get(pos, None)
if node is not None:
# this node was already built, skip
return
@@ -244,7 +244,8 @@ def parent(self, pos):
if pos == 0:
return None
p = int(math.floor((pos - 1) / self.d))
return NodePos(p, self.nodes[p])
node = self.nodes.get(p, None)
return NodePos(p, node)

def children(self, pos):
"""Return all children nodes for node at position ``pos``.
@@ -281,7 +282,8 @@ def child(self, parent, pos):
child node.
"""
cd = self.d * parent + pos + 1
return NodePos(cd, self.nodes[cd])
node = self.nodes.get(cd, None)
return NodePos(cd, node)

def save(self, path, storage=None, sparseness=0.0):
"""Saves an SBT description locally and node data to a storage.
@@ -353,7 +355,7 @@ def save(self, path, storage=None, sparseness=0.0):
data['filename'] = node.save(data['filename'])
structure[i] = data

notify("{} of {} nodes saved".format(n, total_nodes), end='\r')
notify("{} of {} nodes saved".format(n+1, total_nodes), end='\r')

notify("\nFinished saving nodes, now saving SBT json file.")
info['nodes'] = structure
@@ -398,7 +400,7 @@ def load(cls, location, leaf_loader=None, storage=None):
try:
x.count(10)
except TypeError:
raise Exception("khmer version is too old; need >= 2.1.")
raise Exception("khmer version is too old; need >= 2.1,<3")

if leaf_loader is None:
leaf_loader = Leaf.load
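The try/except above is a duck-typed version probe: in khmer >= 2.1, Nodegraph.count() accepts an integer hash value, while older releases raise TypeError. A hedged standalone sketch of the same probe; the Nodegraph constructor arguments here are illustrative and assume khmer 2.1's API:

```python
import khmer

ng = khmer.Nodegraph(31, 1e5, 4)   # ksize, starting table size, number of tables

try:
    ng.count(10)                   # khmer >= 2.1 accepts an integer hash value
except TypeError:
    raise Exception("khmer version is too old; need >= 2.1,<3")
```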
@@ -516,10 +518,34 @@ def _load_v3(cls, info, leaf_loader, dirname, storage):
tree.nodes = sbt_nodes
tree.missing_nodes = {i for i in range(max_node)
if i not in sbt_nodes}
tree.max_node = max_node
# TODO: this might not be true with combine...
tree.next_node = max_node

tree._fill_max_n_below()

return tree

def _fill_max_n_below(self):
for i, n in self.nodes.items():
if isinstance(n, Leaf):
parent = self.parent(i)
if parent.pos not in self.missing_nodes:
max_n_below = parent.node.metadata.get('max_n_below', 0)
max_n_below = max(len(n.data.minhash.get_mins()),
max_n_below)
parent.node.metadata['max_n_below'] = max_n_below

current = parent
parent = self.parent(parent.pos)
while parent and parent.pos not in self.missing_nodes:
max_n_below = parent.node.metadata.get('max_n_below', 0)
max_n_below = max(current.node.metadata['max_n_below'],
max_n_below)
parent.node.metadata['max_n_below'] = max_n_below
current = parent
parent = self.parent(parent.pos)


def print_dot(self):
print("""
digraph G {
@@ -543,7 +569,7 @@ def print(self):
visited, stack = set(), [0]
while stack:
node_p = stack.pop()
node_g = self.nodes[node_p]
node_g = self.nodes.get(node_p, None)
if node_p not in visited and node_g is not None:
visited.add(node_p)
depth = int(math.floor(math.log(node_p + 1, self.d)))
@@ -573,7 +599,9 @@ def _leaves(self, pos=0):
yield (i, node)

def leaves(self):
return [c for c in self.nodes.values() if isinstance(c, Leaf)]
for c in self.nodes.values():
if isinstance(c, Leaf):
yield c

def combine(self, other):
larger, smaller = self, other
@@ -593,22 +621,20 @@
for level in range(1, levels + 1):
for tree in (larger, smaller):
for pos in range(n_previous, n_next):
if tree.nodes[pos] is not None:
if tree.nodes.get(pos, None) is not None:
new_node = copy(tree.nodes[pos])
if isinstance(new_node, Node):
# An internal node, we need to update the name
new_node.name = "internal.{}".format(current_pos)
new_nodes[current_pos] = new_node
else:
del tree.nodes[pos]
current_pos += 1
n_previous = n_next
n_next = n_previous + int(self.d ** level)
current_pos = n_next

# reset max_node, next time we add a node it will find the next
# reset next_node, next time we add a node it will find the next
# empty position
self.max_node = 2
self.next_node = 2

# TODO: do we want to return a new tree, or merge into this one?
self.nodes = new_nodes
17 changes: 12 additions & 5 deletions sourmash_lib/sbtmh.py
@@ -75,27 +75,34 @@ def data(self, new_data):

def search_minhashes(node, sig, threshold, results=None, downsample=True):
mins = sig.minhash.get_mins()
score = 0

if isinstance(node, SigLeaf):
try:
matches = node.data.minhash.count_common(sig.minhash)
score = node.data.minhash.similarity(sig.minhash)
except Exception as e:
if 'mismatch in max_hash' in str(e) and downsample:
xx = sig.minhash.downsample_max_hash(node.data.minhash)
yy = node.data.minhash.downsample_max_hash(sig.minhash)

matches = yy.count_common(xx)
score = yy.similarity(xx)
else:
raise

else: # Node or Leaf, Nodegraph by minhash comparison
matches = sum(1 for value in mins if node.data.get(value))
if len(mins):
matches = sum(1 for value in mins if node.data.get(value))
max_mins = node.metadata.get('max_n_below', -1)
if max_mins == -1:
raise Exception('cannot do similarity search on this SBT; need to rebuild.')
score = float(matches) / max_mins

if results is not None:
results[node.name] = float(matches) / len(mins)
results[node.name] = score

if len(mins) and float(matches) / len(mins) >= threshold:
if score >= threshold:
return 1

return 0


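For internal nodes (plain Nodegraphs), search_minhashes() now scores matches against the 'max_n_below' metadata, the largest MinHash size of any leaf beneath the node, rather than against the query size alone. Purely illustrative arithmetic with made-up numbers:

```python
# Hypothetical values for one internal node during a search:
matches = 150        # query hashes present in the node's Nodegraph
max_n_below = 500    # largest len(minhash.get_mins()) of any leaf below this node

score = matches / max_n_below    # 0.3, the value recorded in `results` above
threshold = 0.08

# find() keeps descending into this subtree only if the score clears the
# threshold; SBTs saved without 'max_n_below' metadata hit the
# "need to rebuild" error above unless _fill_max_n_below() filled it at load.
assert score >= threshold
```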
10 changes: 7 additions & 3 deletions sourmash_lib/search.py
@@ -1,8 +1,9 @@
from __future__ import division
from collections import namedtuple
import sys

import sourmash_lib
from .logging import notify
from .logging import notify, error
from .signature import SourmashSignature
from .sbtmh import search_minhashes, search_minhashes_containment
from .sbtmh import SearchMinHashesFindBest
@@ -47,8 +48,11 @@ def search_databases(query, databases, threshold, do_containment, best_only):
tree = sbt_or_siglist
for leaf in tree.find(search_fn, query, threshold):
similarity = query_match(leaf.data)
if similarity >= threshold and \
leaf.data.md5sum() not in found_md5:

# tree search should always/only return matches above threshold
assert similarity >= threshold

if leaf.data.md5sum() not in found_md5:
sr = SearchResult(similarity=similarity,
match_sig=leaf.data,
md5=leaf.data.md5sum(),
(Diffs for the remaining changed files are not shown here.)
