diff --git a/sourmash/commands.py b/sourmash/commands.py index d6757f4e4a..ecfce4c6d1 100644 --- a/sourmash/commands.py +++ b/sourmash/commands.py @@ -704,7 +704,8 @@ def index(args): parser.add_argument('-s', '--sparseness', type=float, default=.0, help='What percentage of internal nodes will not be saved. ' 'Ranges from 0.0 (save all nodes) to 1.0 (no nodes saved)') - + parser.add_argument('--scaled', type=float, default=0, + help='downsample signatures to this scaled factor') sourmash_args.add_moltype_args(parser) args = parser.parse_args(args) @@ -725,6 +726,10 @@ def index(args): if args.sparseness < 0 or args.sparseness > 1.0: error('sparseness must be in range [0.0, 1.0].') + if args.scaled: + args.scaled = int(args.scaled) + notify('downsampling signatures to scaled={}', args.scaled) + notify('loading {} files into SBT', len(inp_files)) n = 0 @@ -733,6 +738,7 @@ def index(args): nums = set() scaleds = set() for f in inp_files: + notify('\r...reading from {} ({} signatures so far)', f, n, end='') siglist = sig.load_signatures(f, ksize=args.ksize, select_moltype=moltype) @@ -742,6 +748,9 @@ def index(args): ksizes.add(ss.minhash.ksize) moltypes.add(sourmash_args.get_moltype(ss)) nums.add(ss.minhash.num) + + if args.scaled: + ss.minhash = ss.minhash.downsample_scaled(args.scaled) scaleds.add(ss.minhash.scaled) leaf = SigLeaf(ss.md5sum(), ss) @@ -768,6 +777,7 @@ def index(args): error('nums = {}; scaleds = {}', repr(nums), repr(scaleds)) sys.exit(-1) + notify('') # did we load any!? 
if n == 0: @@ -835,8 +845,8 @@ def search(args): # set up the search databases databases = sourmash_args.load_dbs_and_sigs(args.databases, query, - not args.containment, - args.traverse_directory) + not args.containment, + args.traverse_directory) if not len(databases): error('Nothing found to search!') diff --git a/sourmash/search.py b/sourmash/search.py index 16eca15847..fa6663c6e5 100644 --- a/sourmash/search.py +++ b/sourmash/search.py @@ -46,7 +46,20 @@ def search_databases(query, databases, threshold, do_containment, best_only, search_fn = SearchMinHashesFindBest().search tree = obj - for leaf in tree.find(search_fn, query, threshold): + + # figure out scaled value of tree, downsample query if needed. + leaf = next(iter(tree.leaves())) + tree_mh = leaf.data.minhash + + tree_query = query + if tree_mh.scaled and query.minhash.scaled and \ + tree_mh.scaled > query.minhash.scaled: + resampled_query_mh = tree_query.minhash + resampled_query_mh = resampled_query_mh.downsample_scaled(tree_mh.scaled) + tree_query = SourmashSignature(resampled_query_mh) + + # now, search! + for leaf in tree.find(search_fn, tree_query, threshold): similarity = query_match(leaf.data) # tree search should always/only return matches above threshold diff --git a/sourmash/sourmash_args.py b/sourmash/sourmash_args.py index 02e9030fbe..9e10cbeb95 100644 --- a/sourmash/sourmash_args.py +++ b/sourmash/sourmash_args.py @@ -190,6 +190,7 @@ def check_signatures_are_compatible(query, subject): def check_tree_is_compatible(treename, tree, query, is_similarity_query): + # get a minhash from the tree leaf = next(iter(tree.leaves())) tree_mh = leaf.data.minhash diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index fa8f6a55a6..40121f19bd 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -1328,6 +1328,45 @@ def test_do_sourmash_index_multiscaled_fail(): assert 'trying to build an SBT with incompatible signatures.' 
in err +@utils.in_tempdir +def test_do_sourmash_index_multiscaled_rescale(c): + # test sourmash index --scaled + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + + c.run_sourmash('compute', '--scaled', '10', testdata1) + c.run_sourmash('compute', '--scaled', '1', testdata2) + + c.run_sourmash('index', '-k', '31', 'zzz', + '--scaled', '10', + 'short.fa.sig', + 'short2.fa.sig') + + print(c) + assert c.last_result.status == 0 + + +@utils.in_tempdir +def test_do_sourmash_index_multiscaled_rescale_fail(c): + # test sourmash index --scaled with invalid rescaling (10 -> 5) + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + + c.run_sourmash('compute', '--scaled', '10', testdata1) + c.run_sourmash('compute', '--scaled', '1', testdata2) + # this should fail: cannot go from a scaled value of 10 to 5 + + with pytest.raises(ValueError) as e: + c.run_sourmash('index', '-k', '31', 'zzz', + '--scaled', '5', + 'short.fa.sig', + 'short2.fa.sig') + + print(e.value) + assert c.last_result.status == -1 + assert 'new scaled 5 is lower than current sample scaled 10' in c.last_result.err + + def test_do_sourmash_sbt_search_output(): with utils.TempDirectory() as location: testdata1 = utils.get_test_data('short.fa') @@ -1691,6 +1730,9 @@ def test_search_metagenome_traverse(): assert '13 matches; showing first 3:' in out +# explanation: you cannot downsample a scaled SBT to match a scaled +# signature, so make sure that when you try such a search, it fails! +# (you *can* downsample a signature to match an SBT.) 
def test_search_metagenome_downsample(): with utils.TempDirectory() as location: testdata_glob = utils.get_test_data('gather/GCF*.sig') @@ -1743,6 +1785,34 @@ def test_search_metagenome_downsample_containment(): assert '12 matches; showing first 3:' in out +@utils.in_tempdir +def test_search_metagenome_downsample_index(c): + # does same search as search_metagenome_downsample_containment but + # rescales during indexing + # + # the search is expected to succeed and to report the same number of + # matches as when the downscaling happens during search. + # + testdata_glob = utils.get_test_data('gather/GCF*.sig') + testdata_sigs = glob.glob(testdata_glob) + + query_sig = utils.get_test_data('gather/combined.sig') + + # downscale during indexing, rather than during search. + c.run_sourmash('index', 'gcf_all', '-k', '21', '--scaled', '100000', + *testdata_sigs) + + assert os.path.exists(c.output('gcf_all.sbt.json')) + + c.run_sourmash('search', query_sig, 'gcf_all', '-k', '21', + '--containment') + print(c) + + assert ' 32.9% NC_003198.1 Salmonella enterica subsp. enterica serovar T...' in str(c) + assert ' 29.7% NC_003197.2 Salmonella enterica subsp. enterica serovar T...' in str(c) + assert '12 matches; showing first 3:' in str(c) + + def test_mash_csv_to_sig(): with utils.TempDirectory() as location: testdata1 = utils.get_test_data('short.fa.msh.dump')