Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] SBT feature: combine trees #143

Merged
merged 12 commits into from
Apr 10, 2017
7 changes: 5 additions & 2 deletions sourmash_lib/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from .logging import notify, error

from .commands import (categorize, compare, compute, convert, dump, import_csv,
sbt_gather, sbt_index, sbt_search, search, plot, watch)
sbt_gather, sbt_index, sbt_combine, sbt_search, search,
plot, watch)


def main():
Expand All @@ -17,7 +18,8 @@ def main():
'import_csv': import_csv, 'dump': dump,
'sbt_index': sbt_index, 'sbt_search': sbt_search,
'categorize': categorize, 'sbt_gather': sbt_gather,
'watch': watch, 'convert': convert}
'watch': watch, 'convert': convert,
'sbt_combine': sbt_combine}
parser = argparse.ArgumentParser(
description='work with RNAseq signatures',
usage='''sourmash <command> [<args>]
Expand All @@ -33,6 +35,7 @@ def main():
convert Convert signatures from YAML to JSON.

sbt_index Index signatures with a Sequence Bloom Tree.
sbt_combine Combine multiple Sequence Bloom Trees into a new one.
sbt_search Search a Sequence Bloom Tree.
categorize Categorize signatures with a SBT.
sbt_gather Search a signature for multiple matches.
Expand Down
29 changes: 29 additions & 0 deletions sourmash_lib/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,35 @@ def dump(args):
fp.close()


def sbt_combine(args):
    """Combine several saved SBTs into a single tree and save the result.

    ``args`` is the list of command-line argument strings for this
    sub-command: the name to save the combined SBT under, followed by one
    or more SBTs to combine.  The first SBT is loaded and every following
    one is merged into it with ``SBT.combine``; the result is saved under
    the requested name.

    Raises ``ValueError`` if the trees do not share the same number of
    children per node (``d``).  ``SBT.combine`` computes the positions of
    both trees' nodes from a single ``d``, so trees with different values
    cannot be merged meaningfully.
    """
    from sourmash_lib.sbt import SBT
    from sourmash_lib.sbtmh import SigLeaf

    parser = argparse.ArgumentParser()
    parser.add_argument('sbt_name', help='name to save SBT into')
    parser.add_argument('sbts', nargs='+',
                        help='SBTs to combine to a new SBT')
    parser.add_argument('-x', '--bf-size', type=float, default=1e5)

    sourmash_args.add_moltype_args(parser)

    args = parser.parse_args(args)
    # The return value is not needed here; the call is kept because it may
    # validate the molecule-type flags -- TODO confirm.
    sourmash_args.calculate_moltype(args)

    # nargs='+' guarantees at least one entry in args.sbts.
    first_file, remaining_files = args.sbts[0], args.sbts[1:]
    notify('combining {} SBTs', len(args.sbts))

    tree = SBT.load(first_file, leaf_loader=SigLeaf.load)

    for f in remaining_files:
        new_tree = SBT.load(f, leaf_loader=SigLeaf.load)
        # TODO: also check that the Bloom filter parameters match.
        if new_tree.d != tree.d:
            raise ValueError(
                'cannot combine SBTs with different numbers of children '
                'per node: {} (from "{}") != {}'.format(new_tree.d, f,
                                                        tree.d))
        tree.combine(new_tree)

    notify('saving SBT under "{}"', args.sbt_name)
    tree.save(args.sbt_name)


def sbt_index(args):
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbtmh import search_minhashes, SigLeaf
Expand Down
109 changes: 74 additions & 35 deletions sourmash_lib/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,13 @@ def search_transcript(node, seq, threshold):

from __future__ import print_function, unicode_literals, division

from collections import namedtuple, Mapping
import hashlib
from collections import namedtuple, Mapping, defaultdict
from copy import copy
import json
import math
import os
import random
import shutil
from tempfile import NamedTemporaryFile

import khmer
from khmer import khmer_args
from random import randint
from numpy import array

Expand All @@ -78,19 +74,14 @@ class SBT(object):

def __init__(self, factory, d=2):
self.factory = factory
self.nodes = [None]
self.nodes = defaultdict(lambda: None)
self.d = d
self.max_node = 0

def new_node_pos(self, node):
try:
pos = self.nodes.index(None)
except ValueError:
# There aren't any empty positions left.
# Extend array
height = math.floor(math.log(len(self.nodes), self.d)) + 1
self.nodes += [None] * int(self.d ** height)
pos = self.nodes.index(None)
return pos
while self.nodes[self.max_node] is not None:
self.max_node += 1
return self.max_node

def add_node(self, node):
pos = self.new_node_pos(node)
Expand All @@ -106,6 +97,8 @@ def add_node(self, node):
# - add Leaf, update parent
# 3) parent is a Node (no position available)
# - this is covered by case 1
# 4) parent is None
# this can happen with d != 2, in this case create the parent node
p = self.parent(pos)
if isinstance(p.node, Leaf):
# Create a new internal node
Expand All @@ -123,6 +116,12 @@ def add_node(self, node):
elif isinstance(p.node, Node):
self.nodes[pos] = node
node.update(p.node)
elif p.node is None:
n = Node(self.factory, name="internal." + str(p.pos))
self.nodes[p.pos] = n
c1 = self.children(p.pos)[0]
self.nodes[c1.pos] = node
node.update(n)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to add a test to get some coverage here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw this happen when using bigger SBTs; I'll try to cook up a test without putting too much data in the repo.


# update all parents!
p = self.parent(p.pos)
Expand Down Expand Up @@ -243,15 +242,14 @@ def _load_v1(jnodes, leaf_loader, dirname):
# TODO error!
raise ValueError("Empty tree!")

sbt_nodes = []
sbt_nodes = defaultdict(lambda: None)

sample_bf = os.path.join(dirname, jnodes[0]['filename'])
ksize, tablesize, ntables = khmer.extract_nodegraph_info(sample_bf)[:3]
factory = GraphFactory(ksize, tablesize, ntables)

for jnode in jnodes:
for i, jnode in enumerate(jnodes):
if jnode is None:
sbt_nodes.append(None)
continue

if 'internal' in jnode['filename']:
Expand All @@ -260,7 +258,7 @@ def _load_v1(jnodes, leaf_loader, dirname):
else:
sbt_node = leaf_loader(jnode, dirname)

sbt_nodes.append(sbt_node)
sbt_nodes[i] = sbt_node

tree = SBT(factory)
tree.nodes = sbt_nodes
Expand All @@ -274,15 +272,14 @@ def _load_v2(cls, info, leaf_loader, dirname):
if nodes[0] is None:
raise ValueError("Empty tree!")

sbt_nodes = []
sbt_nodes = defaultdict(lambda: None)

sample_bf = os.path.join(dirname, nodes[0]['filename'])
k, size, ntables = khmer.extract_nodegraph_info(sample_bf)[:3]
factory = GraphFactory(k, size, ntables)

for i, node in sorted(nodes.items()):
for k, node in nodes.items():
if node is None:
sbt_nodes.append(None)
continue

if 'internal' in node['filename']:
Expand All @@ -291,7 +288,7 @@ def _load_v2(cls, info, leaf_loader, dirname):
else:
sbt_node = leaf_loader(node, dirname)

sbt_nodes.append(sbt_node)
sbt_nodes[k] = sbt_node

tree = cls(factory, d=info['d'])
tree.nodes = sbt_nodes
Expand All @@ -308,16 +305,13 @@ def print_dot(self):
edge [arrowsize=0.8];
""")

for i, node in iter(self):
if node is None:
continue

p = self.parent(i)
if p is not None:
if isinstance(node, Leaf):
print('"', p.pos, '"', '->', '"', node.name, '";')
else:
print('"', p.pos, '"', '->', '"', i, '";')
for i, node in list(self.nodes.items()):
if isinstance(node, Node):
print('"{}" [shape=box fillcolor=gray style=filled]'.format(
node.name))
for j, child in self.children(i):
if child is not None:
print('"{}" -> "{}"'.format(node.name, child.name))
print("}")

def print(self):
Expand All @@ -334,9 +328,51 @@ def print(self):
if c.pos not in visited)

def __iter__(self):
for i, node in enumerate(self.nodes):
for i, node in self.nodes.items():
yield (i, node)

def leaves(self):
    """Return a list of all stored nodes that are ``Leaf`` instances."""
    found = []
    for candidate in self.nodes.values():
        if isinstance(candidate, Leaf):
            found.append(candidate)
    return found

def combine(self, other):
    """Merge ``other`` into this tree under a new root and return ``self``.

    A fresh internal node becomes position 0.  The tree with more entries
    in ``nodes`` ("larger") is placed first and the other ("smaller")
    second: their old roots land at positions 1 and 2, and each following
    level is copied into consecutive positions, larger tree first.

    Nodes are shallow-copied with ``copy``; copied internal nodes are
    renamed to ``internal.<new position>``.  ``self.nodes`` is replaced by
    the merged mapping.  Both input trees' ``nodes`` mappings may have
    ``None`` entries removed as a side effect of the scan.
    """
    larger, smaller = self, other
    if len(other.nodes) > len(self.nodes):
        larger, smaller = other, self

    # New root; both old roots push their data into it via update().
    n = Node(self.factory, name="internal.0")
    larger.nodes[0].update(n)
    smaller.nodes[0].update(n)
    new_nodes = defaultdict(lambda: None)
    new_nodes[0] = n

    # Number of levels to scan, estimated from the entry count of the
    # larger tree (plus one).
    levels = int(math.ceil(math.log(len(larger.nodes), self.d))) + 1
    current_pos = 1
    # [n_previous, n_next) is the range of positions, in the *source*
    # trees, of the level currently being copied.
    n_previous = 0
    n_next = 1
    for level in range(1, levels + 1):
        for tree in (larger, smaller):
            for pos in range(n_previous, n_next):
                if tree.nodes[pos] is not None:
                    new_node = copy(tree.nodes[pos])
                    if isinstance(new_node, Node):
                        # An internal node, we need to update the name
                        new_node.name = "internal.{}".format(current_pos)
                    new_nodes[current_pos] = new_node
                else:
                    # nodes is a defaultdict, so the lookup above created
                    # a None entry for a missing position; drop it again.
                    del tree.nodes[pos]
                # Empty source positions still consume a target position
                # so that the subtree layout is preserved.
                current_pos += 1
        # Advance the source range to the next level; the target position
        # restarts at the first position of the next level.
        n_previous = n_next
        n_next = n_previous + int(self.d ** level)
        current_pos = n_next

    # reset max_node, next time we add a node it will find the next
    # empty position
    self.max_node = 2

    # TODO: do we want to return a new tree, or merge into this one?
    self.nodes = new_nodes
    return self


class Node(object):
"Internal node of SBT."
Expand Down Expand Up @@ -374,6 +410,9 @@ def load(info, dirname):
new_node = Node(info['factory'], name=info['name'], fullpath=filename)
return new_node

def update(self, parent):
    """Merge this node's ``data`` into ``parent.data`` in place."""
    target = parent.data
    target.update(self.data)


class Leaf(object):
def __init__(self, metadata, data=None, name=None, fullpath=None):
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
@pytest.fixture(params=[True, False])
def track_abundance(request):
    """Parametrized fixture returning ``True`` and then ``False``."""
    flag = request.param
    return flag


@pytest.fixture(params=[2, 5, 10])
def n_children(request):
    """Parametrized fixture returning each of 2, 5 and 10 in turn."""
    count = request.param
    return count
6 changes: 6 additions & 0 deletions tests/sourmash_tst_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from io import StringIO


# Paths of the signature files, each joined under the 'demo' directory.
SIG_FILES = [
    os.path.join('demo', sig_name)
    for sig_name in (
        "SRR2060939_1.sig",
        "SRR2060939_2.sig",
        "SRR2241509_1.sig",
        "SRR2255622_1.sig",
        "SRR453566_1.sig",
        "SRR453569_1.sig",
        "SRR453570_1.sig",
    )
]


def scriptpath(scriptname='sourmash'):
"""Return the path to the scripts, in both dev and install situations."""
# note - it doesn't matter what the scriptname is here, as long as
Expand Down
Loading