diff --git a/sourmash_lib/commands.py b/sourmash_lib/commands.py index 50e0a12cf1..716040f680 100644 --- a/sourmash_lib/commands.py +++ b/sourmash_lib/commands.py @@ -486,6 +486,7 @@ def sbt_index(args): help='signatures to load into SBT') parser.add_argument('-k', '--ksize', type=int, default=None) parser.add_argument('--traverse-directory', action='store_true') + parser.add_argument('--append', action='store_true', default=False) parser.add_argument('-x', '--bf-size', type=float, default=1e5) sourmash_args.add_moltype_args(parser) @@ -493,8 +494,11 @@ def sbt_index(args): args = parser.parse_args(args) moltype = sourmash_args.calculate_moltype(args) - factory = GraphFactory(1, args.bf_size, 4) - tree = SBT(factory) + if args.append: + tree = SBT.load(args.sbt_name, leaf_loader=SigLeaf.load) + else: + factory = GraphFactory(1, args.bf_size, 4) + tree = SBT(factory) if args.traverse_directory: inp_files = list(sourmash_args.traverse_find_sigs(args.signatures)) diff --git a/sourmash_lib/sbt.py b/sourmash_lib/sbt.py index 08245a4584..75f35b214c 100644 --- a/sourmash_lib/sbt.py +++ b/sourmash_lib/sbt.py @@ -156,6 +156,7 @@ def parent(self, pos): if pos == 0: return None p = int(math.floor((pos - 1) / self.d)) + self.nodes[p] = self.nodes[p].do_load() return NodePos(p, self.nodes[p]) def children(self, pos): @@ -163,6 +164,8 @@ def children(self, pos): def child(self, parent, pos): cd = self.d * parent + pos + 1 + if self.nodes[cd] is not None: + self.nodes[cd] = self.nodes[cd].do_load() return NodePos(cd, self.nodes[cd]) def save(self, tag): @@ -184,6 +187,7 @@ def save(self, tag): structure[i] = None continue + node = node.do_load() basename = os.path.basename(node.name) data = { 'filename': os.path.join('.sbt.' + basetag, diff --git a/sourmash_lib/test_sourmash.py b/sourmash_lib/test_sourmash.py index 524e7b4170..801e94fc5d 100644 --- a/sourmash_lib/test_sourmash.py +++ b/sourmash_lib/test_sourmash.py @@ -693,6 +693,52 @@ def test_do_sourmash_sbt_index_traverse(): assert testdata2 in out +def test_do_sourmash_sbt_index_append(): + with utils.TempDirectory() as location: + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + testdata3 = utils.get_test_data('short3.fa') + status, out, err = utils.runscript('sourmash', + ['compute', testdata1, testdata2, testdata3], + in_directory=location) + + status, out, err = utils.runscript('sourmash', + ['sbt_index', 'zzz', + 'short.fa.sig', 'short2.fa.sig'], + in_directory=location) + + assert os.path.exists(os.path.join(location, 'zzz.sbt.json')) + + sbt_name = os.path.join(location, 'zzz',) + sig_loc = os.path.join(location, 'short3.fa.sig') + status, out, err = utils.runscript('sourmash', + ['sbt_search', sbt_name, sig_loc]) + print(out) + + assert testdata1 in out + assert testdata2 in out + assert testdata3 not in out + + status, out, err = utils.runscript('sourmash', + ['sbt_index', '--append', + 'zzz', + 'short3.fa.sig'], + in_directory=location) + + assert os.path.exists(os.path.join(location, 'zzz.sbt.json')) + + sbt_name = os.path.join(location, 'zzz',) + sig_loc = os.path.join(location, 'short3.fa.sig') + status, out, err = utils.runscript('sourmash', + ['sbt_search', '--threshold', '0.95', + sbt_name, sig_loc]) + print(out) + + assert testdata1 not in out + assert testdata2 in out + assert testdata3 in out + + def test_do_sourmash_sbt_search_otherdir(): with utils.TempDirectory() as location: testdata1 = utils.get_test_data('short.fa')