From a2db3eb9498eda8a57e305bd710290c0066dcc03 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Tue, 7 Mar 2017 16:31:58 -0800 Subject: [PATCH 01/12] Use defaultdict instead of list for SBT, fix bug during insertion --- sourmash_lib/sbt.py | 43 +++++++++++++++++++++++-------------------- tests/test_sbt.py | 4 ++++ 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 08245a4584..e0e1de5f05 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -47,7 +47,7 @@ def search_transcript(node, seq, threshold): from __future__ import print_function, unicode_literals, division -from collections import namedtuple, Mapping +from collections import namedtuple, Mapping, defaultdict import hashlib import json import math @@ -78,19 +78,14 @@ class SBT(object): def __init__(self, factory, d=2): self.factory = factory - self.nodes = [None] + self.nodes = defaultdict(lambda: None) self.d = d def new_node_pos(self, node): - try: - pos = self.nodes.index(None) - except ValueError: - # There aren't any empty positions left. - # Extend array - height = math.floor(math.log(len(self.nodes), self.d)) + 1 - self.nodes += [None] * int(self.d ** height) - pos = self.nodes.index(None) - return pos + if self.nodes: + return max(self.nodes) + 1 + else: + return 0 def add_node(self, node): pos = self.new_node_pos(node) @@ -106,6 +101,8 @@ def add_node(self, node): # - add Leaf, update parent # 3) parent is a Node (no position available) # - this is covered by case 1 + # 4) parent is None + # this can happen with d != 2, in this case create the parent node p = self.parent(pos) if isinstance(p.node, Leaf): # Create a new internal node @@ -123,6 +120,12 @@ def add_node(self, node): elif isinstance(p.node, Node): self.nodes[pos] = node node.update(p.node) + elif p.node is None: + n = Node(self.factory, name="internal." + str(p.pos)) + self.nodes[p.pos] = n + c1 = self.children(p.pos)[0] + self.nodes[c1.pos] = node + node.update(n) # update all parents! p = self.parent(p.pos) @@ -243,15 +246,14 @@ def _load_v1(jnodes, leaf_loader, dirname): # TODO error! raise ValueError("Empty tree!") - sbt_nodes = [] + sbt_nodes = defaultdict(lambda: None) sample_bf = os.path.join(dirname, jnodes[0]['filename']) ksize, tablesize, ntables = khmer.extract_nodegraph_info(sample_bf)[:3] factory = GraphFactory(ksize, tablesize, ntables) - for jnode in jnodes: + for i, jnode in enumerate(jnodes): if jnode is None: - sbt_nodes.append(None) continue if 'internal' in jnode['filename']: @@ -260,7 +262,7 @@ def _load_v1(jnodes, leaf_loader, dirname): else: sbt_node = leaf_loader(jnode, dirname) - sbt_nodes.append(sbt_node) + sbt_nodes[i] = sbt_node tree = SBT(factory) tree.nodes = sbt_nodes @@ -274,15 +276,14 @@ def _load_v2(cls, info, leaf_loader, dirname): if nodes[0] is None: raise ValueError("Empty tree!") - sbt_nodes = [] + sbt_nodes = defaultdict(lambda: None) sample_bf = os.path.join(dirname, nodes[0]['filename']) k, size, ntables = khmer.extract_nodegraph_info(sample_bf)[:3] factory = GraphFactory(k, size, ntables) - for i, node in sorted(nodes.items()): + for k, node in nodes.items(): if node is None: - sbt_nodes.append(None) continue if 'internal' in node['filename']: @@ -291,7 +292,7 @@ def _load_v2(cls, info, leaf_loader, dirname): else: sbt_node = leaf_loader(node, dirname) - sbt_nodes.append(sbt_node) + sbt_nodes[k] = sbt_node tree = cls(factory, d=info['d']) tree.nodes = sbt_nodes @@ -334,9 +335,11 @@ def print(self): if c.pos not in visited) def __iter__(self): - for i, node in enumerate(self.nodes): + for i, node in self.nodes.items(): yield (i, node) + def leaves(self): + return [c for c in self.nodes.values() if isinstance(c, Leaf)] class Node(object): "Internal node of SBT." diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 8ed3710e7d..4709381a48 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -185,12 +185,16 @@ def test_binary_nary_tree(): trees[5] = SBT(factory, d=5) trees[10] = SBT(factory, d=10) + n_leaves = 0 for f in SIG_FILES: sig = next(signature.load_signatures(utils.get_test_data(f))) leaf = SigLeaf(os.path.basename(f), sig) for tree in trees.values(): tree.add_node(leaf) to_search = leaf + n_leaves += 1 + + assert all([len(t.leaves()) == n_leaves for t in trees.values()]) results = {} print('*' * 60) From 93ee1b808c9950516e91fe995bc8d6ab350e430b Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Wed, 8 Mar 2017 12:49:24 -0800 Subject: [PATCH 02/12] Fix node position calculation --- sourmash_lib/sbt.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index e0e1de5f05..5211de1927 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -80,12 +80,13 @@ def __init__(self, factory, d=2): self.factory = factory self.nodes = defaultdict(lambda: None) self.d = d + self.max_node = 0 def new_node_pos(self, node): - if self.nodes: - return max(self.nodes) + 1 - else: - return 0 + while self.nodes[self.max_node] is not None: + self.max_node += 1 + next_node = self.max_node + return next_node def add_node(self, node): pos = self.new_node_pos(node) From 2a4258d3c10f5416c25d465f8143678e71f32212 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Wed, 8 Mar 2017 12:49:50 -0800 Subject: [PATCH 03/12] Add a dumb combine method, and a test --- sourmash_lib/sbt.py | 10 ++++++++++ tests/test_sbt.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 5211de1927..88bc88a12d 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -342,6 +342,16 @@ def __iter__(self): def leaves(self): return [c for c in self.nodes.values() if isinstance(c, Leaf)] + def combine(self, other): + # TODO: first pass, the dumb way: + # 1) find all leaves in other + # 2) add all leaves in other to self + # Why is is dumb? Because we already have all the internal nodes + # ready in other, so instead we can reuse them. + for leaf in other.leaves(): + self.add_node(leaf) + + class Node(object): "Internal node of SBT." diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 4709381a48..129bf84256 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -74,6 +74,7 @@ def search_kmer_in_list(kmer): print([ x.metadata for x in root.find(search_kmer, "CAAAA") ]) print([ x.metadata for x in root.find(search_kmer, "GAAAA") ]) + def test_longer_search(): ksize = 5 factory = GraphFactory(ksize, 100, 3) @@ -205,3 +206,30 @@ def test_binary_nary_tree(): assert set(results[2]) == set(results[5]) assert set(results[5]) == set(results[10]) + + +def test_sbt_combine(): + factory = GraphFactory(31, 1e5, 4) + tree = SBT(factory) + tree_1 = SBT(factory) + tree_2 = SBT(factory) + + n_leaves = 0 + for f in SIG_FILES: + sig = next(signature.load_signatures(utils.get_test_data(f))) + leaf = SigLeaf(os.path.basename(f), sig) + tree.add_node(leaf) + if n_leaves < 4: + tree_1.add_node(leaf) + else: + tree_2.add_node(leaf) + n_leaves += 1 + + tree_1.combine(tree_2) + + t1_leaves = {str(l) for l in tree_1.leaves()} + t_leaves = {str(l) for l in tree.leaves()} + + assert len(t1_leaves) == n_leaves + assert len(t_leaves) == len(t1_leaves) + assert t1_leaves == t_leaves From 18dc09b1a7f7256b70d7c024685a13fabce5668f Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Wed, 8 Mar 2017 21:42:02 -0800 Subject: [PATCH 04/12] Update combine implementation, still failing for d>2 --- sourmash_lib/sbt.py | 43 +++++++++++++++++++++++++++++++------------ tests/conftest.py | 5 +++++ tests/test_sbt.py | 29 +++++++++++++++++++---------- 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 88bc88a12d..5bf8edd477 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -48,16 +48,12 @@ def search_transcript(node, seq, threshold): from __future__ import print_function, unicode_literals, division from collections import namedtuple, Mapping, defaultdict -import hashlib +from copy import copy import json import math import os -import random -import shutil -from tempfile import NamedTemporaryFile import khmer -from khmer import khmer_args from random import randint from numpy import array @@ -343,13 +339,33 @@ def leaves(self): return [c for c in self.nodes.values() if isinstance(c, Leaf)] def combine(self, other): - # TODO: first pass, the dumb way: - # 1) find all leaves in other - # 2) add all leaves in other to self - # Why is is dumb? Because we already have all the internal nodes - # ready in other, so instead we can reuse them. - for leaf in other.leaves(): - self.add_node(leaf) + larger, smaller = self, other + if len(other.nodes) > len(self.nodes): + larger, smaller = other, self + + n = Node(self.factory, name="internal.0") + larger.nodes[0].update(n) + smaller.nodes[0].update(n) + new_nodes = defaultdict(lambda: None) + new_nodes[0] = n + + levels = int(math.ceil(math.log(len(larger.nodes), self.d))) + 1 + current_pos = 1 + for level in range(1, levels): + for tree in (larger, smaller): + for pos in range(int(self.d ** (level - 1)), + int(self.d ** level)): + if tree.nodes[pos - 1] is not None: + new_node = copy(tree.nodes[pos - 1]) + if isinstance(new_node, Node): + # An internal node, we need to update the name + new_node.name = "internal.{}".format(current_pos) + new_nodes[current_pos] = new_node + current_pos += 1 + + # TODO: do we want to return a new tree, or merge into this one? + self.nodes = new_nodes + return self class Node(object): @@ -388,6 +404,9 @@ def load(info, dirname): new_node = Node(info['factory'], name=info['name'], fullpath=filename) return new_node + def update(self, parent): + parent.data.update(self.data) + class Leaf(object): def __init__(self, metadata, data=None, name=None, fullpath=None): diff --git a/tests/conftest.py b/tests/conftest.py index 74428ed5f3..37ac98a848 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,3 +4,8 @@ @pytest.fixture(params=[True, False]) def track_abundance(request): return request.param + + +@pytest.fixture(params=[2, 5, 10]) +def n_children(request): + return request.param diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 129bf84256..9e3488e08d 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -15,9 +15,9 @@ ] -def test_simple(): +def test_simple(n_children): factory = GraphFactory(5, 100, 3) - root = SBT(factory) + root = SBT(factory, d=n_children) leaf1 = Leaf("a", factory()) leaf1.data.count('AAAAA') @@ -75,10 +75,10 @@ def search_kmer_in_list(kmer): print([ x.metadata for x in root.find(search_kmer, "GAAAA") ]) -def test_longer_search(): +def test_longer_search(n_children): ksize = 5 factory = GraphFactory(ksize, 100, 3) - root = SBT(factory) + root = SBT(factory, d=n_children) leaf1 = Leaf("a", factory()) leaf1.data.count('AAAAA') @@ -150,9 +150,9 @@ def test_tree_v1_load(): assert len(results_v1) == 4 -def test_tree_save_load(): +def test_tree_save_load(n_children): factory = GraphFactory(31, 1e5, 4) - tree = SBT(factory) + tree = SBT(factory, d=n_children) for f in SIG_FILES: sig = next(signature.load_signatures(utils.get_test_data(f))) @@ -208,11 +208,11 @@ def test_binary_nary_tree(): assert set(results[5]) == set(results[10]) -def test_sbt_combine(): +def test_sbt_combine(n_children): factory = GraphFactory(31, 1e5, 4) - tree = SBT(factory) - tree_1 = SBT(factory) - tree_2 = SBT(factory) + tree = SBT(factory, d=n_children) + tree_1 = SBT(factory, d=n_children) + tree_2 = SBT(factory, d=n_children) n_leaves = 0 for f in SIG_FILES: @@ -233,3 +233,12 @@ def test_sbt_combine(): assert len(t1_leaves) == n_leaves assert len(t_leaves) == len(t1_leaves) assert t1_leaves == t_leaves + + to_search = next(signature.load_signatures(utils.get_test_data(SIG_FILES[0]))) + t1_result = [str(s) for s in tree_1.find(search_minhashes, + to_search, 0.1)] + tree_result = [str(s) for s in tree.find(search_minhashes, + to_search, 0.1)] + assert t1_result == tree_result + + # TODO: save and load both trees From e1ffac1b0332426f364c8882e865fec9181228e8 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Thu, 9 Mar 2017 19:01:17 -0800 Subject: [PATCH 05/12] Simplify print_dot; use sets in tests --- sourmash_lib/sbt.py | 17 +++++++---------- tests/test_sbt.py | 21 +++++++++++---------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 5bf8edd477..00b34cd553 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -306,16 +306,13 @@ def print_dot(self): edge [arrowsize=0.8]; """) - for i, node in iter(self): - if node is None: - continue - - p = self.parent(i) - if p is not None: - if isinstance(node, Leaf): - print('"', p.pos, '"', '->', '"', node.name, '";') - else: - print('"', p.pos, '"', '->', '"', i, '";') + for i, node in list(self.nodes.items()): + if isinstance(node, Node): + print('"{}" [shape=box fillcolor=gray style=filled]'.format( + node.name)) + for j, child in self.children(i): + if child is not None: + print('"{}" -> "{}"'.format(node.name, child.name)) print("}") def print(self): diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 9e3488e08d..6a3ae248ca 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -162,7 +162,8 @@ def test_tree_save_load(n_children): print('*' * 60) print("{}:".format(to_search.metadata)) - old_result = [str(s) for s in tree.find(search_minhashes, to_search.data, 0.1)] + old_result = {str(s) for s in tree.find(search_minhashes, + to_search.data, 0.1)} print(*old_result, sep='\n') with utils.TempDirectory() as location: @@ -172,8 +173,8 @@ def test_tree_save_load(n_children): print('*' * 60) print("{}:".format(to_search.metadata)) - new_result = [str(s) for s in tree.find(search_minhashes, - to_search.data, 0.1)] + new_result = {str(s) for s in tree.find(search_minhashes, + to_search.data, 0.1)} print(*new_result, sep='\n') assert old_result == new_result @@ -201,11 +202,11 @@ def test_binary_nary_tree(): print('*' * 60) print("{}:".format(to_search.metadata)) for d, tree in trees.items(): - results[d] = [str(s) for s in tree.find(search_minhashes, to_search.data, 0.1)] + results[d] = {str(s) for s in tree.find(search_minhashes, to_search.data, 0.1)} print(*results[2], sep='\n') - assert set(results[2]) == set(results[5]) - assert set(results[5]) == set(results[10]) + assert results[2] == results[5] + assert results[5] == results[10] def test_sbt_combine(n_children): @@ -235,10 +236,10 @@ def test_sbt_combine(n_children): assert t1_leaves == t_leaves to_search = next(signature.load_signatures(utils.get_test_data(SIG_FILES[0]))) - t1_result = [str(s) for s in tree_1.find(search_minhashes, - to_search, 0.1)] - tree_result = [str(s) for s in tree.find(search_minhashes, - to_search, 0.1)] + t1_result = {str(s) for s in tree_1.find(search_minhashes, + to_search, 0.1)} + tree_result = {str(s) for s in tree.find(search_minhashes, + to_search, 0.1)} assert t1_result == tree_result # TODO: save and load both trees From 6711dc6a7a74509a0fa527e9262929cd288bb1dc Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Thu, 9 Mar 2017 20:24:48 -0800 Subject: [PATCH 06/12] Working on combine --- sourmash_lib/sbt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 00b34cd553..152a37c38e 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -348,17 +348,20 @@ def combine(self, other): levels = int(math.ceil(math.log(len(larger.nodes), self.d))) + 1 current_pos = 1 - for level in range(1, levels): + n_previous = 0 + n_next = 1 + for level in range(1, levels + 1): for tree in (larger, smaller): - for pos in range(int(self.d ** (level - 1)), - int(self.d ** level)): - if tree.nodes[pos - 1] is not None: - new_node = copy(tree.nodes[pos - 1]) + for pos in range(n_previous, n_next): + if tree.nodes[pos] is not None: + new_node = copy(tree.nodes[pos]) if isinstance(new_node, Node): # An internal node, we need to update the name new_node.name = "internal.{}".format(current_pos) new_nodes[current_pos] = new_node current_pos += 1 + n_previous = n_next + n_next = n_previous + int(self.d ** level) # TODO: do we want to return a new tree, or merge into this one? self.nodes = new_nodes From fa1203a1e5d30600f9f433667f00613b43002f63 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Thu, 9 Mar 2017 21:20:02 -0800 Subject: [PATCH 07/12] fix combine, working now --- sourmash_lib/sbt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 152a37c38e..ddd159ee3d 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -362,6 +362,7 @@ def combine(self, other): current_pos += 1 n_previous = n_next n_next = n_previous + int(self.d ** level) + current_pos = n_next # TODO: do we want to return a new tree, or merge into this one? self.nodes = new_nodes From 450c390fd65daceedd637b3f0a1d539648f6c4b8 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Fri, 10 Mar 2017 10:01:25 -0800 Subject: [PATCH 08/12] Fix max_node calculation, add a test --- sourmash_lib/sbt.py | 9 +++++++-- tests/test_sbt.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index ddd159ee3d..3dffbd7928 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -81,8 +81,7 @@ def __init__(self, factory, d=2): def new_node_pos(self, node): while self.nodes[self.max_node] is not None: self.max_node += 1 - next_node = self.max_node - return next_node + return self.max_node def add_node(self, node): pos = self.new_node_pos(node) @@ -359,11 +358,17 @@ def combine(self, other): # An internal node, we need to update the name new_node.name = "internal.{}".format(current_pos) new_nodes[current_pos] = new_node + if tree.nodes[pos] is None: + del tree.nodes[pos] current_pos += 1 n_previous = n_next n_next = n_previous + int(self.d ** level) current_pos = n_next + # reset max_node, next time we add a node it will find the next + # empty position + self.max_node = 2 + # TODO: do we want to return a new tree, or merge into this one? self.nodes = new_nodes return self diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 6a3ae248ca..b42bbf9a3f 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -243,3 +243,15 @@ def test_sbt_combine(n_children): assert t1_result == tree_result # TODO: save and load both trees + + # check if adding a new node will use the next empty position + next_empty = 0 + for n, d in tree_1.nodes.items(): + if d is None: + next_empty = n + break + if not next_empty: + next_empty = n + 1 + + tree_1.add_node(leaf) + assert tree_1.max_node == next_empty From 2c3c6f250df9fdefccb6f4454cebcfd5eb76fa71 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Fri, 10 Mar 2017 10:48:37 -0800 Subject: [PATCH 09/12] Sigh, need to do_load... --- sourmash_lib/sbt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 3dffbd7928..e28358d031 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -340,8 +340,8 @@ def combine(self, other): larger, smaller = other, self n = Node(self.factory, name="internal.0") - larger.nodes[0].update(n) - smaller.nodes[0].update(n) + larger.nodes[0].do_load().update(n) + smaller.nodes[0].do_load().update(n) new_nodes = defaultdict(lambda: None) new_nodes[0] = n From b3fea9617553403f7eb8683121dc0376104e32f4 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Fri, 10 Mar 2017 13:09:28 -0800 Subject: [PATCH 10/12] Add a missing do_load() --- sourmash_lib/sbt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index e28358d031..4d1ac423b8 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -183,6 +183,7 @@ def save(self, tag): structure[i] = None continue + node = node.do_load() basename = os.path.basename(node.name) data = { 'filename': os.path.join('.sbt.' + basetag, From c6888d7f72fde9f496faf7746a9ace9818013768 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Fri, 7 Apr 2017 23:35:00 +0000 Subject: [PATCH 11/12] Add a sbt_combine command, with tests --- sourmash_lib/__main__.py | 7 +++++-- sourmash_lib/commands.py | 29 ++++++++++++++++++++++++++++ sourmash_lib/sbt.py | 2 +- tests/sourmash_tst_utils.py | 6 ++++++ tests/test_sbt.py | 17 ++++++----------- tests/test_sourmash.py | 38 +++++++++++++++++++++++++++++++++++++ 6 files changed, 85 insertions(+), 14 deletions(-) diff --git a/sourmash_lib/__main__.py b/sourmash_lib/__main__.py index 2d32ddcbc7..edbf3b2a2c 100644 --- a/sourmash_lib/__main__.py +++ b/sourmash_lib/__main__.py @@ -8,7 +8,8 @@ from .logging import notify, error from .commands import (categorize, compare, compute, convert, dump, import_csv, - sbt_gather, sbt_index, sbt_search, search, plot, watch) + sbt_gather, sbt_index, sbt_combine, sbt_search, search, + plot, watch) def main(): @@ -17,7 +18,8 @@ def main(): 'import_csv': import_csv, 'dump': dump, 'sbt_index': sbt_index, 'sbt_search': sbt_search, 'categorize': categorize, 'sbt_gather': sbt_gather, - 'watch': watch, 'convert': convert} + 'watch': watch, 'convert': convert, + 'sbt_combine': sbt_combine} parser = argparse.ArgumentParser( description='work with RNAseq signatures', usage='''sourmash [] @@ -33,6 +35,7 @@ def main(): convert Convert signatures from YAML to JSON. sbt_index Index signatures with a Sequence Bloom Tree. +sbt_combine Combine multiple Sequence Bloom Trees into a new one. sbt_search Search a Sequence Bloom Tree. categorize Categorize signatures with a SBT. sbt_gather Search a signature for multiple matches. diff --git a/sourmash_lib/commands.py b/sourmash_lib/commands.py index ab6aa4cdc4..fd8a738252 100644 --- a/sourmash_lib/commands.py +++ b/sourmash_lib/commands.py @@ -480,6 +480,35 @@ def dump(args): fp.close() +def sbt_combine(args): + from sourmash_lib.sbt import SBT, GraphFactory + from sourmash_lib.sbtmh import SigLeaf + + parser = argparse.ArgumentParser() + parser.add_argument('sbt_name', help='name to save SBT into') + parser.add_argument('sbts', nargs='+', + help='SBTs to combine to a new SBT') + parser.add_argument('-x', '--bf-size', type=float, default=1e5) + + sourmash_args.add_moltype_args(parser) + + args = parser.parse_args(args) + moltype = sourmash_args.calculate_moltype(args) + + inp_files = list(args.sbts) + notify('combining {} SBTs', len(inp_files)) + + tree = SBT.load(inp_files.pop(0), leaf_loader=SigLeaf.load) + + for f in inp_files: + new_tree = SBT.load(f, leaf_loader=SigLeaf.load) + # TODO: check if parameters are the same for both trees! + tree.combine(new_tree) + + notify('saving SBT under "{}"', args.sbt_name) + tree.save(args.sbt_name) + + def sbt_index(args): from sourmash_lib.sbt import SBT, GraphFactory from sourmash_lib.sbtmh import search_minhashes, SigLeaf diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 4d1ac423b8..d31f66713f 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -359,7 +359,7 @@ def combine(self, other): # An internal node, we need to update the name new_node.name = "internal.{}".format(current_pos) new_nodes[current_pos] = new_node - if tree.nodes[pos] is None: + else: del tree.nodes[pos] current_pos += 1 n_previous = n_next diff --git a/tests/sourmash_tst_utils.py b/tests/sourmash_tst_utils.py index 312d3e5198..7eee60c82a 100644 --- a/tests/sourmash_tst_utils.py +++ b/tests/sourmash_tst_utils.py @@ -17,6 +17,12 @@ from io import StringIO +SIG_FILES = [os.path.join('demo', f) for f in ( + "SRR2060939_1.sig", "SRR2060939_2.sig", "SRR2241509_1.sig", + "SRR2255622_1.sig", "SRR453566_1.sig", "SRR453569_1.sig", "SRR453570_1.sig") +] + + def scriptpath(scriptname='sourmash'): """Return the path to the scripts, in both dev and install situations.""" # note - it doesn't matter what the scriptname is here, as long as diff --git a/tests/test_sbt.py b/tests/test_sbt.py index b42bbf9a3f..d4b07d2182 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -9,12 +9,6 @@ from sourmash_lib.sbtmh import SigLeaf, search_minhashes -SIG_FILES = [os.path.join('demo', f) for f in ( - "SRR2060939_1.sig", "SRR2060939_2.sig", "SRR2241509_1.sig", - "SRR2255622_1.sig", "SRR453566_1.sig", "SRR453569_1.sig", "SRR453570_1.sig") -] - - def test_simple(n_children): factory = GraphFactory(5, 100, 3) root = SBT(factory, d=n_children) @@ -138,7 +132,7 @@ def test_tree_v1_load(): tree_v2 = SBT.load(utils.get_test_data('v2.sbt.json'), leaf_loader=SigLeaf.load) - testdata1 = utils.get_test_data(SIG_FILES[0]) + testdata1 = utils.get_test_data(utils.SIG_FILES[0]) to_search = next(signature.load_signatures(testdata1)) results_v1 = {str(s) for s in tree_v1.find(search_minhashes, @@ -154,7 +148,7 @@ def test_tree_save_load(n_children): factory = GraphFactory(31, 1e5, 4) tree = SBT(factory, d=n_children) - for f in SIG_FILES: + for f in utils.SIG_FILES: sig = next(signature.load_signatures(utils.get_test_data(f))) leaf = SigLeaf(os.path.basename(f), sig) tree.add_node(leaf) @@ -188,7 +182,7 @@ def test_binary_nary_tree(): trees[10] = SBT(factory, d=10) n_leaves = 0 - for f in SIG_FILES: + for f in utils.SIG_FILES: sig = next(signature.load_signatures(utils.get_test_data(f))) leaf = SigLeaf(os.path.basename(f), sig) for tree in trees.values(): @@ -216,7 +210,7 @@ def test_sbt_combine(n_children): tree_2 = SBT(factory, d=n_children) n_leaves = 0 - for f in SIG_FILES: + for f in utils.SIG_FILES: sig = next(signature.load_signatures(utils.get_test_data(f))) leaf = SigLeaf(os.path.basename(f), sig) tree.add_node(leaf) @@ -235,7 +229,8 @@ def test_sbt_combine(n_children): assert len(t_leaves) == len(t1_leaves) assert t1_leaves == t_leaves - to_search = next(signature.load_signatures(utils.get_test_data(SIG_FILES[0]))) + to_search = next(signature.load_signatures( + utils.get_test_data(utils.SIG_FILES[0]))) t1_result = {str(s) for s in tree_1.find(search_minhashes, to_search, 0.1)} tree_result = {str(s) for s in tree.find(search_minhashes, diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index a1077f6f1e..11d49393b0 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -807,6 +807,44 @@ def test_do_sourmash_sbt_index_traverse(): assert testdata2 in out +def test_do_sourmash_sbt_combine(): + with utils.TempDirectory() as location: + files = [utils.get_test_data(f) for f in utils.SIG_FILES] + + status, out, err = utils.runscript('sourmash', + ['sbt_index', 'zzz'] + files, + in_directory=location) + + assert os.path.exists(os.path.join(location, 'zzz.sbt.json')) + + status, out, err = utils.runscript('sourmash', + ['sbt_combine', 'joined', + 'zzz.sbt.json', 'zzz.sbt.json'], + in_directory=location) + + assert os.path.exists(os.path.join(location, 'joined.sbt.json')) + + filename = os.path.splitext(os.path.basename(utils.SIG_FILES[0]))[0] + + status, out, err = utils.runscript('sourmash', + ['sbt_search', 'zzz'] + [files[0]], + in_directory=location) + print(out) + + assert out.count(filename) == 1 + + status, out, err = utils.runscript('sourmash', + ['sbt_search', 'joined'] + [files[0]], + in_directory=location) + print(out) + + # TODO: signature is loaded more than once, + # so checking if we get two results. + # If we ever start reporting only one match (even if appears repeated), + # change this test too! + assert out.count(filename) == 2 + + def test_do_sourmash_sbt_search_otherdir(): with utils.TempDirectory() as location: testdata1 = utils.get_test_data('short.fa') From f16930aef96308b13737b5c73d47dba996a7e68a Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Mon, 10 Apr 2017 20:13:18 +0000 Subject: [PATCH 12/12] Remove do_load --- sourmash_lib/sbt.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index d31f66713f..abc1f35734 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -183,7 +183,6 @@ def save(self, tag): structure[i] = None continue - node = node.do_load() basename = os.path.basename(node.name) data = { 'filename': os.path.join('.sbt.' + basetag, @@ -341,8 +340,8 @@ def combine(self, other): larger, smaller = other, self n = Node(self.factory, name="internal.0") - larger.nodes[0].do_load().update(n) - smaller.nodes[0].do_load().update(n) + larger.nodes[0].update(n) + smaller.nodes[0].update(n) new_nodes = defaultdict(lambda: None) new_nodes[0] = n