diff --git a/src/sourmash/cli/compare.py b/src/sourmash/cli/compare.py index 4af6f38fda..81fdf1c5d1 100644 --- a/src/sourmash/cli/compare.py +++ b/src/sourmash/cli/compare.py @@ -58,9 +58,13 @@ def subparser(subparsers): '--max-containment', action='store_true', help='calculate max containment instead of similarity' ) + subparser.add_argument( + '--avg-containment', '--average-containment', action='store_true', + help='calculate average containment instead of similarity' + ) subparser.add_argument( '--estimate-ani', '--ANI', '--ani', action='store_true', - help='return ANI estimated from jaccard, containment, or max containment; see https://doi.org/10.1101/2022.01.11.475870' + help='return ANI estimated from jaccard, containment, average containment, or max containment; see https://doi.org/10.1101/2022.01.11.475870' ) subparser.add_argument( '--from-file', diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py index 6d4888a82a..e148fc53a6 100644 --- a/src/sourmash/commands.py +++ b/src/sourmash/commands.py @@ -8,7 +8,7 @@ import screed from .compare import (compare_all_pairs, compare_serial_containment, - compare_serial_max_containment) + compare_serial_max_containment, compare_serial_avg_containment) from . import MinHash from .sbtmh import load_sbt_index, create_sbt_index from . 
import signature as sig @@ -98,16 +98,17 @@ def compare(args): sys.exit(-1) is_containment = False - if args.containment or args.max_containment: + if args.containment or args.max_containment or args.avg_containment: is_containment = True - if args.containment and args.max_containment: - notify("ERROR: cannot specify both --containment and --max-containment!") + containment_args = [args.containment, args.max_containment, args.avg_containment] + if sum(containment_args) > 1: + notify("ERROR: cannot specify more than one containment argument!") sys.exit(-1) # complain if --containment and not is_scaled if is_containment and not is_scaled: - error('must use scaled signatures with --containment and --max-containment') + error('must use scaled signatures with --containment, --max-containment, and --avg-containment') sys.exit(-1) # complain if --ani and not is_scaled @@ -123,7 +124,7 @@ def compare(args): if is_containment or return_ani: track_abundances = any(( s.minhash.track_abundance for s in siglist )) if track_abundances: - notify('NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.') + notify('NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances.') # if using --scaled, downsample appropriately printed_scaled_msg = False @@ -152,6 +153,8 @@ def compare(args): similarity = compare_serial_containment(siglist, return_ani=return_ani) elif args.max_containment: similarity = compare_serial_max_containment(siglist, return_ani=return_ani) + elif args.avg_containment: + similarity = compare_serial_avg_containment(siglist, return_ani=return_ani) else: similarity = compare_all_pairs(siglist, args.ignore_abundance, n_jobs=args.processes, return_ani=return_ani) diff --git a/src/sourmash/compare.py b/src/sourmash/compare.py index 25d357b1ca..4166da3d8c 100644 --- a/src/sourmash/compare.py +++ b/src/sourmash/compare.py @@ -106,6 +106,37 @@ def compare_serial_max_containment(siglist, *, 
def compare_serial_avg_containment(siglist, *, downsample=False, return_ani=False):
    """Compare all pairs of signatures and return a matrix of average containments.

    Pairs are processed serially in a single process, so this is best
    used only when there are few signatures.

    :param list siglist: signatures to compare; each must provide
        ``avg_containment`` and, when ``return_ani`` is set,
        ``avg_containment_ani``
    :param bool downsample: forwarded unchanged to the signature methods
        (the previous docstring described it as "downsample by scaled if True")
    :param bool return_ani: if True, fill the matrix with the result of
        ``avg_containment_ani`` instead of ``avg_containment``; a ``None``
        result is stored as 0.0
    :return: symmetric ``len(siglist) x len(siglist)`` np.array with 1.0 on
        the diagonal
    """
    import numpy as np

    n = len(siglist)

    # Diagonal cells are never revisited below, so they keep the initial 1.0.
    containments = np.ones((n, n))

    # combinations() yields each unordered pair once, e.g. (A, B) but not
    # (B, A); both mirrored cells are filled from that single comparison.
    for i, j in itertools.combinations(range(n), 2):
        if return_ani:
            value = siglist[j].avg_containment_ani(siglist[i], downsample=downsample)
            # Identity test rather than `== None`: equality would be routed
            # through the returned object's own __eq__.
            if value is None:
                value = 0.0
        else:
            value = siglist[j].avg_containment(siglist[i], downsample=downsample)
        containments[i, j] = containments[j, i] = value

    return containments
Written to use in pool.imap as it can only be given one argument.""" diff --git a/tests/test_compare.py b/tests/test_compare.py index c77a3bc4b2..66c46e4386 100644 --- a/tests/test_compare.py +++ b/tests/test_compare.py @@ -7,7 +7,7 @@ import sourmash from sourmash.compare import (compare_all_pairs, compare_parallel, compare_serial, compare_serial_containment, - compare_serial_max_containment) + compare_serial_max_containment, compare_serial_avg_containment) import sourmash_tst_utils as utils @@ -130,3 +130,15 @@ def test_compare_serial_containmentANI(scaled_siglist): [0., 1., 0.97715525, 1.]]) np.testing.assert_array_almost_equal(max_containment_ANI, true_max_containment_ANI, decimal=3) + + # check avg_containment ANI + avg_containment_ANI = compare_serial_avg_containment(scaled_siglist, return_ani=True) + print(avg_containment_ANI) + + true_avg_containment_ANI = np.array( + [[1., 0., 0., 0.], + [0., 1., 0.97046289, 0.99333757], + [0., 0.97046289, 1., 0.97697067], + [0., 0.99333757, 0.97697067, 1.]]) + + np.testing.assert_array_almost_equal(avg_containment_ANI, true_avg_containment_ANI, decimal=3) diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index 708a7bcf02..b5c8855c99 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -515,6 +515,46 @@ def test_compare_max_containment(c): assert containment == mat_val, (i, j) +@utils.in_tempdir +def test_compare_avg_containment(c): + import numpy + + testdata_glob = utils.get_test_data('scaled/*.sig') + testdata_sigs = glob.glob(testdata_glob) + + c.run_sourmash('compare', '--avg-containment', '-k', '31', + '--csv', 'output.csv', *testdata_sigs) + + # load the matrix output of compare --avg-containment + with open(c.output('output.csv'), 'rt') as fp: + r = iter(csv.reader(fp)) + headers = next(r) + + mat = numpy.zeros((len(headers), len(headers))) + for i, row in enumerate(r): + for j, val in enumerate(row): + mat[i][j] = float(val) + + print(mat) + + # load in all the input signatures + idx_to_sig = 
dict() + for idx, filename in enumerate(testdata_sigs): + ss = sourmash.load_one_signature(filename, ksize=31) + idx_to_sig[idx] = ss + + # check explicit containment against output of compare + for i in range(len(idx_to_sig)): + ss_i = idx_to_sig[i] + for j in range(len(idx_to_sig)): + ss_j = idx_to_sig[j] + containment = ss_j.avg_containment(ss_i) + containment = round(containment, 3) + mat_val = round(mat[i][j], 3) + + assert containment == mat_val, (i, j) + + @utils.in_tempdir def test_compare_max_containment_and_containment(c): testdata_glob = utils.get_test_data('scaled/*.sig') @@ -526,7 +566,35 @@ def test_compare_max_containment_and_containment(c): '--csv', 'output.csv', *testdata_sigs) print(c.last_result.err) - assert "ERROR: cannot specify both --containment and --max-containment!" in c.last_result.err + assert "ERROR: cannot specify more than one containment argument!" in c.last_result.err + + +@utils.in_tempdir +def test_compare_avg_containment_and_containment(c): + testdata_glob = utils.get_test_data('scaled/*.sig') + testdata_sigs = glob.glob(testdata_glob) + + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('compare', '--avg-containment', '-k', '31', + '--containment', + '--csv', 'output.csv', *testdata_sigs) + + print(c.last_result.err) + assert "ERROR: cannot specify more than one containment argument!" in c.last_result.err + + +@utils.in_tempdir +def test_compare_avg_containment_and_max_containment(c): + testdata_glob = utils.get_test_data('scaled/*.sig') + testdata_sigs = glob.glob(testdata_glob) + + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('compare', '--avg-containment', '-k', '31', + '--max-containment', + '--csv', 'output.csv', *testdata_sigs) + + print(c.last_result.err) + assert "ERROR: cannot specify more than one containment argument!" 
in c.last_result.err @utils.in_tempdir @@ -538,7 +606,7 @@ def test_compare_containment_abund_flatten(c): print(c.last_result.out) print(c.last_result.err) - assert 'NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.' in \ + assert 'NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances.' in \ c.last_result.err @@ -551,7 +619,7 @@ def test_compare_ani_abund_flatten(c): print(c.last_result.out) print(c.last_result.err) - assert 'NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.' in \ + assert 'NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances.' in \ c.last_result.err @@ -564,7 +632,7 @@ def test_compare_containment_require_scaled(c): c.run_sourmash('compare', '--containment', '-k', '31', s47, s63, fail_ok=True) - assert 'must use scaled signatures with --containment and --max-containment' in \ + assert 'must use scaled signatures with --containment, --max-containment, and --avg-containment' in \ c.last_result.err assert c.last_result.status != 0 @@ -1495,7 +1563,7 @@ def test_search_containment_s10_no_max(run): def test_search_max_containment_s10_pairwise(runtmp): - # check --containment for s10/s10-small + # check --max-containment for s10/s10-small q1 = utils.get_test_data('scaled/genome-s10.fa.gz.sig') q2 = utils.get_test_data('scaled/genome-s10-small.fa.gz.sig') @@ -5849,6 +5917,53 @@ def test_compare_max_containment_ani(c): assert containment_ani == mat_val, (i, j) +@utils.in_tempdir +def test_compare_avg_containment_ani(c): + import numpy + + sigfiles = ["2.fa.sig", "2+63.fa.sig", "47.fa.sig", "63.fa.sig"] + testdata_sigs = [utils.get_test_data(c) for c in sigfiles] + + c.run_sourmash('compare', '--avg-containment', '-k', '31', + '--estimate-ani', '--csv', 'output.csv', *testdata_sigs) + + # load the matrix output of compare --avg-containment --estimate-ani + with 
open(c.output('output.csv'), 'rt') as fp: + r = iter(csv.reader(fp)) + headers = next(r) + + mat = numpy.zeros((len(headers), len(headers))) + for i, row in enumerate(r): + for j, val in enumerate(row): + mat[i][j] = float(val) + + print(mat) + + # load in all the input signatures + idx_to_sig = dict() + for idx, filename in enumerate(testdata_sigs): + ss = sourmash.load_one_signature(filename, ksize=31) + idx_to_sig[idx] = ss + + # check explicit avg containment against output of compare + for i in range(len(idx_to_sig)): + ss_i = idx_to_sig[i] + for j in range(len(idx_to_sig)): + mat_val = round(mat[i][j], 3) + print(mat_val) + if i == j: + assert 1 == mat_val + else: + ss_j = idx_to_sig[j] + containment_ani = ss_j.avg_containment_ani(ss_i) + if containment_ani is not None: + containment_ani = round(containment_ani, 3) + else: + containment_ani = 0.0 + + assert containment_ani == mat_val, (i, j) + + @utils.in_tempdir def test_compare_ANI_require_scaled(c): s47 = utils.get_test_data('num/47.fa.sig') @@ -5858,7 +5973,7 @@ def test_compare_ANI_require_scaled(c): with pytest.raises(SourmashCommandFailed) as exc: c.run_sourmash('compare', '--containment', '--estimate-ani', '-k', '31', s47, s63, fail_ok=True) - assert 'must use scaled signatures with --containment and --max-containment' in \ + assert 'must use scaled signatures with --containment, --max-containment, and --avg-containment' in \ c.last_result.err assert c.last_result.status != 0