Skip to content

Commit

Permalink
add compare --avg-containment (#2056)
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed May 16, 2022
1 parent 851dc2b commit 3827367
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 14 deletions.
6 changes: 5 additions & 1 deletion src/sourmash/cli/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@ def subparser(subparsers):
'--max-containment', action='store_true',
help='calculate max containment instead of similarity'
)
subparser.add_argument(
'--avg-containment', '--average-containment', action='store_true',
help='calculate average containment instead of similarity'
)
subparser.add_argument(
'--estimate-ani', '--ANI', '--ani', action='store_true',
help='return ANI estimated from jaccard, containment, or max containment; see https://doi.org/10.1101/2022.01.11.475870'
help='return ANI estimated from jaccard, containment, average containment, or max containment; see https://doi.org/10.1101/2022.01.11.475870'
)
subparser.add_argument(
'--from-file',
Expand Down
15 changes: 9 additions & 6 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import screed
from .compare import (compare_all_pairs, compare_serial_containment,
compare_serial_max_containment)
compare_serial_max_containment, compare_serial_avg_containment)
from . import MinHash
from .sbtmh import load_sbt_index, create_sbt_index
from . import signature as sig
Expand Down Expand Up @@ -98,16 +98,17 @@ def compare(args):
sys.exit(-1)

is_containment = False
if args.containment or args.max_containment:
if args.containment or args.max_containment or args.avg_containment:
is_containment = True

if args.containment and args.max_containment:
notify("ERROR: cannot specify both --containment and --max-containment!")
containment_args = [args.containment, args.max_containment, args.avg_containment]
if sum(containment_args) > 1:
notify("ERROR: cannot specify more than one containment argument!")
sys.exit(-1)

# complain if --containment and not is_scaled
if is_containment and not is_scaled:
error('must use scaled signatures with --containment and --max-containment')
error('must use scaled signatures with --containment, --max-containment, and --avg-containment')
sys.exit(-1)

# complain if --ani and not is_scaled
Expand All @@ -123,7 +124,7 @@ def compare(args):
if is_containment or return_ani:
track_abundances = any(( s.minhash.track_abundance for s in siglist ))
if track_abundances:
notify('NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.')
notify('NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances.')

# if using --scaled, downsample appropriately
printed_scaled_msg = False
Expand Down Expand Up @@ -152,6 +153,8 @@ def compare(args):
similarity = compare_serial_containment(siglist, return_ani=return_ani)
elif args.max_containment:
similarity = compare_serial_max_containment(siglist, return_ani=return_ani)
elif args.avg_containment:
similarity = compare_serial_avg_containment(siglist, return_ani=return_ani)
else:
similarity = compare_all_pairs(siglist, args.ignore_abundance,
n_jobs=args.processes, return_ani=return_ani)
Expand Down
31 changes: 31 additions & 0 deletions src/sourmash/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,37 @@ def compare_serial_max_containment(siglist, *, downsample=False, return_ani=Fals
return containments


def compare_serial_avg_containment(siglist, *, downsample=False, return_ani=False):
    """Compare all combinations of signatures and return a matrix
    of average containments.

    Processes combinations serially on a single process. Best to only
    use when there are few signatures.

    :param list siglist: list of signatures to compare
    :param boolean downsample: forwarded unchanged to each signature's
        ``avg_containment`` / ``avg_containment_ani`` call
    :param boolean return_ani: if True, fill the matrix with the result of
        ``avg_containment_ani`` instead of ``avg_containment``; a result of
        ``None`` is recorded as 0.0
    :return: np.array similarity matrix (symmetric, with 1.0 on the diagonal)
    """
    import numpy as np

    n = len(siglist)

    # Start from all ones so the diagonal (self vs. self) is already 1.0;
    # only the off-diagonal cells are filled in below.
    containments = np.ones((n, n))

    # Combinations makes all unique sets of pairs, e.g. (A, B) but not (B, A);
    # each value is written to both (i, j) and (j, i).
    for i, j in itertools.combinations(range(n), 2):
        if return_ani:
            value = siglist[j].avg_containment_ani(siglist[i],
                                                   downsample=downsample)
            # None is stored as 0.0 so the matrix stays purely numeric.
            if value is None:
                value = 0.0
        else:
            value = siglist[j].avg_containment(siglist[i],
                                               downsample=downsample)

        containments[i][j] = containments[j][i] = value

    return containments


def similarity_args_unpack(args, ignore_abundance, *, downsample, return_ani=False):
"""Helper function to unpack the arguments. Written to use in pool.imap
as it can only be given one argument."""
Expand Down
14 changes: 13 additions & 1 deletion tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sourmash
from sourmash.compare import (compare_all_pairs, compare_parallel,
compare_serial, compare_serial_containment,
compare_serial_max_containment)
compare_serial_max_containment, compare_serial_avg_containment)
import sourmash_tst_utils as utils


Expand Down Expand Up @@ -130,3 +130,15 @@ def test_compare_serial_containmentANI(scaled_siglist):
[0., 1., 0.97715525, 1.]])

np.testing.assert_array_almost_equal(max_containment_ANI, true_max_containment_ANI, decimal=3)

# check avg_containment ANI
avg_containment_ANI = compare_serial_avg_containment(scaled_siglist, return_ani=True)
print(avg_containment_ANI)

true_avg_containment_ANI = np.array(
[[1., 0., 0., 0.],
[0., 1., 0.97046289, 0.99333757],
[0., 0.97046289, 1., 0.97697067],
[0., 0.99333757, 0.97697067, 1.]])

np.testing.assert_array_almost_equal(avg_containment_ANI, true_avg_containment_ANI, decimal=3)
127 changes: 121 additions & 6 deletions tests/test_sourmash.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,46 @@ def test_compare_max_containment(c):
assert containment == mat_val, (i, j)


@utils.in_tempdir
def test_compare_avg_containment(c):
    """Check `compare --avg-containment` CSV output against avg_containment().

    Runs the command on the scaled test signatures at k=31, reads the
    resulting matrix back from the CSV file, and compares every cell with a
    direct pairwise ``avg_containment`` call, to three decimal places.
    """
    import numpy

    testdata_glob = utils.get_test_data('scaled/*.sig')
    testdata_sigs = glob.glob(testdata_glob)

    c.run_sourmash('compare', '--avg-containment', '-k', '31',
                   '--csv', 'output.csv', *testdata_sigs)

    # load the matrix output of compare --avg-containment; the first CSV row
    # is consumed as the header and its length sets the matrix dimension
    with open(c.output('output.csv'), 'rt') as fp:
        r = iter(csv.reader(fp))
        headers = next(r)

        mat = numpy.zeros((len(headers), len(headers)))
        for i, row in enumerate(r):
            for j, val in enumerate(row):
                mat[i][j] = float(val)

    print(mat)

    # load in all the input signatures, keeping command-line order; the
    # checks below rely on list position matching the matrix row/column
    sigs = [sourmash.load_one_signature(filename, ksize=31)
            for filename in testdata_sigs]

    # check explicit avg containment against output of compare
    for i, ss_i in enumerate(sigs):
        for j, ss_j in enumerate(sigs):
            containment = round(ss_j.avg_containment(ss_i), 3)
            mat_val = round(mat[i][j], 3)

            assert containment == mat_val, (i, j)


@utils.in_tempdir
def test_compare_max_containment_and_containment(c):
testdata_glob = utils.get_test_data('scaled/*.sig')
Expand All @@ -526,7 +566,35 @@ def test_compare_max_containment_and_containment(c):
'--csv', 'output.csv', *testdata_sigs)

print(c.last_result.err)
assert "ERROR: cannot specify both --containment and --max-containment!" in c.last_result.err
assert "ERROR: cannot specify more than one containment argument!" in c.last_result.err


@utils.in_tempdir
def test_compare_avg_containment_and_containment(c):
    """`compare` must fail when --avg-containment and --containment are combined."""
    sig_paths = glob.glob(utils.get_test_data('scaled/*.sig'))

    cmd = ['compare', '--avg-containment', '-k', '31',
           '--containment',
           '--csv', 'output.csv', *sig_paths]
    with pytest.raises(SourmashCommandFailed) as exc:
        c.run_sourmash(*cmd)

    stderr = c.last_result.err
    print(stderr)
    assert "ERROR: cannot specify more than one containment argument!" in stderr


@utils.in_tempdir
def test_compare_avg_containment_and_max_containment(c):
    """`compare` must fail when --avg-containment and --max-containment are combined."""
    sig_paths = glob.glob(utils.get_test_data('scaled/*.sig'))

    cmd = ['compare', '--avg-containment', '-k', '31',
           '--max-containment',
           '--csv', 'output.csv', *sig_paths]
    with pytest.raises(SourmashCommandFailed) as exc:
        c.run_sourmash(*cmd)

    stderr = c.last_result.err
    print(stderr)
    assert "ERROR: cannot specify more than one containment argument!" in stderr


@utils.in_tempdir
Expand All @@ -538,7 +606,7 @@ def test_compare_containment_abund_flatten(c):
print(c.last_result.out)
print(c.last_result.err)

assert 'NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.' in \
assert 'NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances.' in \
c.last_result.err


Expand All @@ -551,7 +619,7 @@ def test_compare_ani_abund_flatten(c):
print(c.last_result.out)
print(c.last_result.err)

assert 'NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.' in \
assert 'NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances.' in \
c.last_result.err


Expand All @@ -564,7 +632,7 @@ def test_compare_containment_require_scaled(c):
c.run_sourmash('compare', '--containment', '-k', '31', s47, s63,
fail_ok=True)

assert 'must use scaled signatures with --containment and --max-containment' in \
assert 'must use scaled signatures with --containment, --max-containment, and --avg-containment' in \
c.last_result.err
assert c.last_result.status != 0

Expand Down Expand Up @@ -1495,7 +1563,7 @@ def test_search_containment_s10_no_max(run):


def test_search_max_containment_s10_pairwise(runtmp):
# check --containment for s10/s10-small
# check --max-containment for s10/s10-small
q1 = utils.get_test_data('scaled/genome-s10.fa.gz.sig')
q2 = utils.get_test_data('scaled/genome-s10-small.fa.gz.sig')

Expand Down Expand Up @@ -5849,6 +5917,53 @@ def test_compare_max_containment_ani(c):
assert containment_ani == mat_val, (i, j)


@utils.in_tempdir
def test_compare_avg_containment_ani(c):
    """Check `compare --avg-containment --estimate-ani` against avg_containment_ani().

    Runs the command on four test signatures at k=31, reads the CSV matrix
    back, and verifies that the diagonal is 1 and that every off-diagonal
    cell matches a direct ``avg_containment_ani`` call (with ``None``
    treated as 0.0), to three decimal places.
    """
    import numpy

    sigfiles = ["2.fa.sig", "2+63.fa.sig", "47.fa.sig", "63.fa.sig"]
    # use a loop name other than `c` so the test fixture is not shadowed
    testdata_sigs = [utils.get_test_data(name) for name in sigfiles]

    c.run_sourmash('compare', '--avg-containment', '-k', '31',
                   '--estimate-ani', '--csv', 'output.csv', *testdata_sigs)

    # load the matrix output of compare --avg-containment --estimate-ani
    with open(c.output('output.csv'), 'rt') as fp:
        r = iter(csv.reader(fp))
        headers = next(r)

        mat = numpy.zeros((len(headers), len(headers)))
        for i, row in enumerate(r):
            for j, val in enumerate(row):
                mat[i][j] = float(val)

    print(mat)

    # load in all the input signatures, keeping command-line order; the
    # checks below rely on list position matching the matrix row/column
    sigs = [sourmash.load_one_signature(filename, ksize=31)
            for filename in testdata_sigs]

    # check explicit avg containment ANI against output of compare
    for i, ss_i in enumerate(sigs):
        for j, ss_j in enumerate(sigs):
            mat_val = round(mat[i][j], 3)
            print(mat_val)

            if i == j:
                # self-comparison: the diagonal is expected to be exactly 1
                assert 1 == mat_val
                continue

            containment_ani = ss_j.avg_containment_ani(ss_i)
            if containment_ani is None:
                # the expected matrix value for a None ANI is 0.0
                containment_ani = 0.0
            else:
                containment_ani = round(containment_ani, 3)

            assert containment_ani == mat_val, (i, j)


@utils.in_tempdir
def test_compare_ANI_require_scaled(c):
s47 = utils.get_test_data('num/47.fa.sig')
Expand All @@ -5858,7 +5973,7 @@ def test_compare_ANI_require_scaled(c):
with pytest.raises(SourmashCommandFailed) as exc:
c.run_sourmash('compare', '--containment', '--estimate-ani', '-k', '31', s47, s63,
fail_ok=True)
assert 'must use scaled signatures with --containment and --max-containment' in \
assert 'must use scaled signatures with --containment, --max-containment, and --avg-containment' in \
c.last_result.err
assert c.last_result.status != 0

Expand Down

0 comments on commit 3827367

Please sign in to comment.