From c0e3476b9dc63b7a3c00d21d23cb898c8b4e47f6 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Mon, 10 Apr 2017 14:04:01 -0700 Subject: [PATCH] SBT feature: combine trees (#143) * Add a new method to SBT, combine. * Use defaultdict instead of list for SBT, fix bug during insertion * Fix node position calculation * Simplify print_dot; use sets in tests * Fix max_node calculation, add a test * Add a sbt_combine command, with tests --- sourmash_lib/__main__.py | 7 ++- sourmash_lib/commands.py | 29 ++++++++++ sourmash_lib/sbt.py | 109 ++++++++++++++++++++++++------------ tests/conftest.py | 5 ++ tests/sourmash_tst_utils.py | 6 ++ tests/test_sbt.py | 91 +++++++++++++++++++++++------- tests/test_sourmash.py | 38 +++++++++++++ 7 files changed, 227 insertions(+), 58 deletions(-) diff --git a/sourmash_lib/__main__.py b/sourmash_lib/__main__.py index 2d32ddcbc7..edbf3b2a2c 100644 --- a/sourmash_lib/__main__.py +++ b/sourmash_lib/__main__.py @@ -8,7 +8,8 @@ from .logging import notify, error from .commands import (categorize, compare, compute, convert, dump, import_csv, - sbt_gather, sbt_index, sbt_search, search, plot, watch) + sbt_gather, sbt_index, sbt_combine, sbt_search, search, + plot, watch) def main(): @@ -17,7 +18,8 @@ def main(): 'import_csv': import_csv, 'dump': dump, 'sbt_index': sbt_index, 'sbt_search': sbt_search, 'categorize': categorize, 'sbt_gather': sbt_gather, - 'watch': watch, 'convert': convert} + 'watch': watch, 'convert': convert, + 'sbt_combine': sbt_combine} parser = argparse.ArgumentParser( description='work with RNAseq signatures', usage='''sourmash [] @@ -33,6 +35,7 @@ def main(): convert Convert signatures from YAML to JSON. sbt_index Index signatures with a Sequence Bloom Tree. +sbt_combine Combine multiple Sequence Bloom Trees into a new one. sbt_search Search a Sequence Bloom Tree. categorize Categorize signatures with a SBT. sbt_gather Search a signature for multiple matches. diff --git a/sourmash_lib/commands.py b/sourmash_lib/commands.py index ab6aa4cdc4..fd8a738252 100644 --- a/sourmash_lib/commands.py +++ b/sourmash_lib/commands.py @@ -480,6 +480,35 @@ def dump(args): fp.close() +def sbt_combine(args): + from sourmash_lib.sbt import SBT, GraphFactory + from sourmash_lib.sbtmh import SigLeaf + + parser = argparse.ArgumentParser() + parser.add_argument('sbt_name', help='name to save SBT into') + parser.add_argument('sbts', nargs='+', + help='SBTs to combine to a new SBT') + parser.add_argument('-x', '--bf-size', type=float, default=1e5) + + sourmash_args.add_moltype_args(parser) + + args = parser.parse_args(args) + moltype = sourmash_args.calculate_moltype(args) + + inp_files = list(args.sbts) + notify('combining {} SBTs', len(inp_files)) + + tree = SBT.load(inp_files.pop(0), leaf_loader=SigLeaf.load) + + for f in inp_files: + new_tree = SBT.load(f, leaf_loader=SigLeaf.load) + # TODO: check if parameters are the same for both trees! + tree.combine(new_tree) + + notify('saving SBT under "{}"', args.sbt_name) + tree.save(args.sbt_name) + + def sbt_index(args): from sourmash_lib.sbt import SBT, GraphFactory from sourmash_lib.sbtmh import search_minhashes, SigLeaf diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 08245a4584..abc1f35734 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -47,17 +47,13 @@ def search_transcript(node, seq, threshold): from __future__ import print_function, unicode_literals, division -from collections import namedtuple, Mapping -import hashlib +from collections import namedtuple, Mapping, defaultdict +from copy import copy import json import math import os -import random -import shutil -from tempfile import NamedTemporaryFile import khmer -from khmer import khmer_args from random import randint from numpy import array @@ -78,19 +74,14 @@ class SBT(object): def __init__(self, factory, d=2): self.factory = factory - self.nodes = [None] + self.nodes = defaultdict(lambda: None) self.d = d + self.max_node = 0 def new_node_pos(self, node): - try: - pos = self.nodes.index(None) - except ValueError: - # There aren't any empty positions left. - # Extend array - height = math.floor(math.log(len(self.nodes), self.d)) + 1 - self.nodes += [None] * int(self.d ** height) - pos = self.nodes.index(None) - return pos + while self.nodes[self.max_node] is not None: + self.max_node += 1 + return self.max_node def add_node(self, node): pos = self.new_node_pos(node) @@ -106,6 +97,8 @@ def add_node(self, node): # - add Leaf, update parent # 3) parent is a Node (no position available) # - this is covered by case 1 + # 4) parent is None + # this can happen with d != 2, in this case create the parent node p = self.parent(pos) if isinstance(p.node, Leaf): # Create a new internal node @@ -123,6 +116,12 @@ def add_node(self, node): elif isinstance(p.node, Node): self.nodes[pos] = node node.update(p.node) + elif p.node is None: + n = Node(self.factory, name="internal." + str(p.pos)) + self.nodes[p.pos] = n + c1 = self.children(p.pos)[0] + self.nodes[c1.pos] = node + node.update(n) # update all parents! p = self.parent(p.pos) @@ -243,15 +242,14 @@ def _load_v1(jnodes, leaf_loader, dirname): # TODO error! raise ValueError("Empty tree!") - sbt_nodes = [] + sbt_nodes = defaultdict(lambda: None) sample_bf = os.path.join(dirname, jnodes[0]['filename']) ksize, tablesize, ntables = khmer.extract_nodegraph_info(sample_bf)[:3] factory = GraphFactory(ksize, tablesize, ntables) - for jnode in jnodes: + for i, jnode in enumerate(jnodes): if jnode is None: - sbt_nodes.append(None) continue if 'internal' in jnode['filename']: @@ -260,7 +258,7 @@ def _load_v1(jnodes, leaf_loader, dirname): else: sbt_node = leaf_loader(jnode, dirname) - sbt_nodes.append(sbt_node) + sbt_nodes[i] = sbt_node tree = SBT(factory) tree.nodes = sbt_nodes @@ -274,15 +272,14 @@ def _load_v2(cls, info, leaf_loader, dirname): if nodes[0] is None: raise ValueError("Empty tree!") - sbt_nodes = [] + sbt_nodes = defaultdict(lambda: None) sample_bf = os.path.join(dirname, nodes[0]['filename']) k, size, ntables = khmer.extract_nodegraph_info(sample_bf)[:3] factory = GraphFactory(k, size, ntables) - for i, node in sorted(nodes.items()): + for k, node in nodes.items(): if node is None: - sbt_nodes.append(None) continue if 'internal' in node['filename']: @@ -291,7 +288,7 @@ def _load_v2(cls, info, leaf_loader, dirname): else: sbt_node = leaf_loader(node, dirname) - sbt_nodes.append(sbt_node) + sbt_nodes[k] = sbt_node tree = cls(factory, d=info['d']) tree.nodes = sbt_nodes @@ -308,16 +305,13 @@ def print_dot(self): edge [arrowsize=0.8]; """) - for i, node in iter(self): - if node is None: - continue - - p = self.parent(i) - if p is not None: - if isinstance(node, Leaf): - print('"', p.pos, '"', '->', '"', node.name, '";') - else: - print('"', p.pos, '"', '->', '"', i, '";') + for i, node in list(self.nodes.items()): + if isinstance(node, Node): + print('"{}" [shape=box fillcolor=gray style=filled]'.format( + node.name)) + for j, child in self.children(i): + if child is not None: + print('"{}" -> "{}"'.format(node.name, child.name)) print("}") def print(self): @@ -334,9 +328,51 @@ def print(self): if c.pos not in visited) def __iter__(self): - for i, node in enumerate(self.nodes): + for i, node in self.nodes.items(): yield (i, node) + def leaves(self): + return [c for c in self.nodes.values() if isinstance(c, Leaf)] + + def combine(self, other): + larger, smaller = self, other + if len(other.nodes) > len(self.nodes): + larger, smaller = other, self + + n = Node(self.factory, name="internal.0") + larger.nodes[0].update(n) + smaller.nodes[0].update(n) + new_nodes = defaultdict(lambda: None) + new_nodes[0] = n + + levels = int(math.ceil(math.log(len(larger.nodes), self.d))) + 1 + current_pos = 1 + n_previous = 0 + n_next = 1 + for level in range(1, levels + 1): + for tree in (larger, smaller): + for pos in range(n_previous, n_next): + if tree.nodes[pos] is not None: + new_node = copy(tree.nodes[pos]) + if isinstance(new_node, Node): + # An internal node, we need to update the name + new_node.name = "internal.{}".format(current_pos) + new_nodes[current_pos] = new_node + else: + del tree.nodes[pos] + current_pos += 1 + n_previous = n_next + n_next = n_previous + int(self.d ** level) + current_pos = n_next + + # reset max_node, next time we add a node it will find the next + # empty position + self.max_node = 2 + + # TODO: do we want to return a new tree, or merge into this one? + self.nodes = new_nodes + return self + class Node(object): "Internal node of SBT." @@ -374,6 +410,9 @@ def load(info, dirname): new_node = Node(info['factory'], name=info['name'], fullpath=filename) return new_node + def update(self, parent): + parent.data.update(self.data) + class Leaf(object): def __init__(self, metadata, data=None, name=None, fullpath=None): diff --git a/tests/conftest.py b/tests/conftest.py index 74428ed5f3..37ac98a848 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,3 +4,8 @@ @pytest.fixture(params=[True, False]) def track_abundance(request): return request.param + + +@pytest.fixture(params=[2, 5, 10]) +def n_children(request): + return request.param diff --git a/tests/sourmash_tst_utils.py b/tests/sourmash_tst_utils.py index 312d3e5198..7eee60c82a 100644 --- a/tests/sourmash_tst_utils.py +++ b/tests/sourmash_tst_utils.py @@ -17,6 +17,12 @@ from io import StringIO +SIG_FILES = [os.path.join('demo', f) for f in ( + "SRR2060939_1.sig", "SRR2060939_2.sig", "SRR2241509_1.sig", + "SRR2255622_1.sig", "SRR453566_1.sig", "SRR453569_1.sig", "SRR453570_1.sig") +] + + def scriptpath(scriptname='sourmash'): """Return the path to the scripts, in both dev and install situations.""" # note - it doesn't matter what the scriptname is here, as long as diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 8ed3710e7d..d4b07d2182 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -9,15 +9,9 @@ from sourmash_lib.sbtmh import SigLeaf, search_minhashes -SIG_FILES = [os.path.join('demo', f) for f in ( - "SRR2060939_1.sig", "SRR2060939_2.sig", "SRR2241509_1.sig", - "SRR2255622_1.sig", "SRR453566_1.sig", "SRR453569_1.sig", "SRR453570_1.sig") -] - - -def test_simple(): +def test_simple(n_children): factory = GraphFactory(5, 100, 3) - root = SBT(factory) + root = SBT(factory, d=n_children) leaf1 = Leaf("a", factory()) leaf1.data.count('AAAAA') @@ -74,10 +68,11 @@ def search_kmer_in_list(kmer): print([ x.metadata for x in root.find(search_kmer, "CAAAA") ]) print([ x.metadata for x in root.find(search_kmer, "GAAAA") ]) -def test_longer_search(): + +def test_longer_search(n_children): ksize = 5 factory = GraphFactory(ksize, 100, 3) - root = SBT(factory) + root = SBT(factory, d=n_children) leaf1 = Leaf("a", factory()) leaf1.data.count('AAAAA') @@ -137,7 +132,7 @@ def test_tree_v1_load(): tree_v2 = SBT.load(utils.get_test_data('v2.sbt.json'), leaf_loader=SigLeaf.load) - testdata1 = utils.get_test_data(SIG_FILES[0]) + testdata1 = utils.get_test_data(utils.SIG_FILES[0]) to_search = next(signature.load_signatures(testdata1)) results_v1 = {str(s) for s in tree_v1.find(search_minhashes, @@ -149,11 +144,11 @@ def test_tree_v1_load(): assert len(results_v1) == 4 -def test_tree_save_load(): +def test_tree_save_load(n_children): factory = GraphFactory(31, 1e5, 4) - tree = SBT(factory) + tree = SBT(factory, d=n_children) - for f in SIG_FILES: + for f in utils.SIG_FILES: sig = next(signature.load_signatures(utils.get_test_data(f))) leaf = SigLeaf(os.path.basename(f), sig) tree.add_node(leaf) @@ -161,7 +156,8 @@ def test_tree_save_load(): print('*' * 60) print("{}:".format(to_search.metadata)) - old_result = [str(s) for s in tree.find(search_minhashes, to_search.data, 0.1)] + old_result = {str(s) for s in tree.find(search_minhashes, + to_search.data, 0.1)} print(*old_result, sep='\n') with utils.TempDirectory() as location: @@ -171,8 +167,8 @@ def test_tree_save_load(): print('*' * 60) print("{}:".format(to_search.metadata)) - new_result = [str(s) for s in tree.find(search_minhashes, - to_search.data, 0.1)] + new_result = {str(s) for s in tree.find(search_minhashes, + to_search.data, 0.1)} print(*new_result, sep='\n') assert old_result == new_result @@ -185,19 +181,72 @@ def test_binary_nary_tree(): trees[5] = SBT(factory, d=5) trees[10] = SBT(factory, d=10) - for f in SIG_FILES: + n_leaves = 0 + for f in utils.SIG_FILES: sig = next(signature.load_signatures(utils.get_test_data(f))) leaf = SigLeaf(os.path.basename(f), sig) for tree in trees.values(): tree.add_node(leaf) to_search = leaf + n_leaves += 1 + + assert all([len(t.leaves()) == n_leaves for t in trees.values()]) results = {} print('*' * 60) print("{}:".format(to_search.metadata)) for d, tree in trees.items(): - results[d] = [str(s) for s in tree.find(search_minhashes, to_search.data, 0.1)] + results[d] = {str(s) for s in tree.find(search_minhashes, to_search.data, 0.1)} print(*results[2], sep='\n') - assert set(results[2]) == set(results[5]) - assert set(results[5]) == set(results[10]) + assert results[2] == results[5] + assert results[5] == results[10] + + +def test_sbt_combine(n_children): + factory = GraphFactory(31, 1e5, 4) + tree = SBT(factory, d=n_children) + tree_1 = SBT(factory, d=n_children) + tree_2 = SBT(factory, d=n_children) + + n_leaves = 0 + for f in utils.SIG_FILES: + sig = next(signature.load_signatures(utils.get_test_data(f))) + leaf = SigLeaf(os.path.basename(f), sig) + tree.add_node(leaf) + if n_leaves < 4: + tree_1.add_node(leaf) + else: + tree_2.add_node(leaf) + n_leaves += 1 + + tree_1.combine(tree_2) + + t1_leaves = {str(l) for l in tree_1.leaves()} + t_leaves = {str(l) for l in tree.leaves()} + + assert len(t1_leaves) == n_leaves + assert len(t_leaves) == len(t1_leaves) + assert t1_leaves == t_leaves + + to_search = next(signature.load_signatures( + utils.get_test_data(utils.SIG_FILES[0]))) + t1_result = {str(s) for s in tree_1.find(search_minhashes, + to_search, 0.1)} + tree_result = {str(s) for s in tree.find(search_minhashes, + to_search, 0.1)} + assert t1_result == tree_result + + # TODO: save and load both trees + + # check if adding a new node will use the next empty position + next_empty = 0 + for n, d in tree_1.nodes.items(): + if d is None: + next_empty = n + break + if not next_empty: + next_empty = n + 1 + + tree_1.add_node(leaf) + assert tree_1.max_node == next_empty diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index a1077f6f1e..11d49393b0 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -807,6 +807,44 @@ def test_do_sourmash_sbt_index_traverse(): assert testdata2 in out +def test_do_sourmash_sbt_combine(): + with utils.TempDirectory() as location: + files = [utils.get_test_data(f) for f in utils.SIG_FILES] + + status, out, err = utils.runscript('sourmash', + ['sbt_index', 'zzz'] + files, + in_directory=location) + + assert os.path.exists(os.path.join(location, 'zzz.sbt.json')) + + status, out, err = utils.runscript('sourmash', + ['sbt_combine', 'joined', + 'zzz.sbt.json', 'zzz.sbt.json'], + in_directory=location) + + assert os.path.exists(os.path.join(location, 'joined.sbt.json')) + + filename = os.path.splitext(os.path.basename(utils.SIG_FILES[0]))[0] + + status, out, err = utils.runscript('sourmash', + ['sbt_search', 'zzz'] + [files[0]], + in_directory=location) + print(out) + + assert out.count(filename) == 1 + + status, out, err = utils.runscript('sourmash', + ['sbt_search', 'joined'] + [files[0]], + in_directory=location) + print(out) + + # TODO: signature is loaded more than once, + # so checking if we get two results. + # If we ever start reporting only one match (even if appears repeated), + # change this test too! + assert out.count(filename) == 2 + + def test_do_sourmash_sbt_search_otherdir(): with utils.TempDirectory() as location: testdata1 = utils.get_test_data('short.fa')