[EXP] add sourmash distance estimation #1788

Closed · wants to merge 46 commits

Changes from 33 commits

Commits (46):
f728db1  init dist utils (bluegenes, Jan 13, 2022)
70d2ae2  add confidence tests (bluegenes, Jan 13, 2022)
6c061a0  optionally return sim/ident instead (bluegenes, Jan 13, 2022)
6da758e  add ani estimation to signature (bluegenes, Jan 13, 2022)
bd6485d  return confidence interval in addition to point estimate (bluegenes, Jan 13, 2022)
2177843  return ANI from compare (bluegenes, Jan 14, 2022)
4d61a38  add containment ANI tests to test_compare (bluegenes, Jan 14, 2022)
0e351b8  init compare ANI tests (bluegenes, Jan 14, 2022)
66a570f  add option (bluegenes, Jan 14, 2022)
530ab76  ani in prefetch and gather (bluegenes, Jan 14, 2022)
361bb4d  test ani from precalculated score (bluegenes, Jan 14, 2022)
b012f5a  return_ANI --> return_ani (bluegenes, Jan 14, 2022)
48b5ccf  remaining return_ANI changes (bluegenes, Jan 14, 2022)
5da01cc  cleanup unused fn (bluegenes, Jan 14, 2022)
f452d9e  spacing (bluegenes, Jan 14, 2022)
64b8b96  fix mc kmers (bluegenes, Jan 17, 2022)
70eb0ee  add corresponding mh methods (bluegenes, Jan 17, 2022)
7819e3c  merge in latest (bluegenes, Jan 17, 2022)
8ec9e10  round kmers for jaccard (bluegenes, Jan 17, 2022)
357aa58  allow confidence tuning (bluegenes, Jan 18, 2022)
af036b8  enable bp --> ANI (bluegenes, Jan 19, 2022)
71a859b  fix c to dist in search (bluegenes, Jan 19, 2022)
d284a24  optionally estimate query ani while summarizing gather results (bluegenes, Jan 19, 2022)
15db152  addl cols in gatherresult (bluegenes, Jan 19, 2022)
ef4dee4  addl prefetch cols; test ani in summarize_gather_at (bluegenes, Jan 20, 2022)
581b404  add tests for ani_threshold (bluegenes, Jan 20, 2022)
0adc5f2  fix typo (bluegenes, Jan 20, 2022)
041edf1  addl tests (bluegenes, Jan 20, 2022)
75402f3  report CI in prefetch, gather bc I wants them for testing (bluegenes, Jan 20, 2022)
ae97dc0  estimate ANI during search (bluegenes, Jan 25, 2022)
ab5f2fe  search_scaled (bluegenes, Jan 26, 2022)
3b9bfe7  wade into downsampling (bluegenes, Jan 26, 2022)
fdec49b  dist comment (bluegenes, Jan 26, 2022)
8e268f0  handle downsampling in minhash, signature ANI fns (bluegenes, Feb 2, 2022)
4aed030  Merge branch 'latest' into add-dist-est (bluegenes, Feb 2, 2022)
845abd4  Merge branch 'latest' into add-dist-est (bluegenes, Feb 8, 2022)
727ee7f  whoops, fix rogue test (bluegenes, Feb 9, 2022)
962469a  Merge branch 'latest' into add-dist-est (bluegenes, Feb 9, 2022)
bdc67a6  Merge branch 'latest' into add-dist-est (bluegenes, Mar 8, 2022)
e579e3a  Add dist est (#1860) (mahmudhera, Mar 12, 2022)
b02601f  Merge branch 'latest' into add-dist-est (bluegenes, Mar 12, 2022)
a349d75  enable just point estimate for distance_to_identity (bluegenes, Mar 12, 2022)
58f2a07  save in progress changes (bluegenes, Mar 13, 2022)
2f4dc37  upd (bluegenes, Apr 1, 2022)
9495198  Merge branch 'latest' into add-dist-est (bluegenes, Apr 2, 2022)
e8c9997  Merge branch 'latest' into add-dist-est (bluegenes, Apr 15, 2022)
4 changes: 4 additions & 0 deletions src/sourmash/cli/compare.py
@@ -58,6 +58,10 @@ def subparser(subparsers):
         '--max-containment', action='store_true',
         help='calculate max containment instead of similarity'
     )
+    subparser.add_argument(
+        '--estimate-ani', '--estimate-ANI', action='store_true',
+        help='return ANI estimated from jaccard, containment, or max containment; see https://doi.org/10.1101/2022.01.11.475870'
+    )
     subparser.add_argument(
         '--from-file',
         help='a text file containing a list of files to load signatures from'
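In practice, this lets compare report an ANI matrix for scaled signatures. A hypothetical invocation (the signature filenames and output path are placeholders, not taken from this PR):

    sourmash compare genome1.sig genome2.sig genome3.sig \
        --containment --estimate-ani --csv compare_ani.csv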
10 changes: 7 additions & 3 deletions src/sourmash/cli/utils.py
@@ -67,10 +67,14 @@ def range_limited_float_type(arg):
     return f


-def add_tax_threshold_arg(parser, default=0.1):
+def add_tax_threshold_arg(parser, containment_default=0.1, ani_default=None):
     parser.add_argument(
-        '--containment-threshold', default=default, type=range_limited_float_type,
-        help=f'minimum containment threshold for classification; default={default}'
+        '--containment-threshold', default=containment_default, type=range_limited_float_type,
+        help=f'minimum containment threshold for classification; default={containment_default}',
     )
+    parser.add_argument(
+        '--ani-threshold', '--aai-threshold', default=ani_default, type=range_limited_float_type,
+        help=f'minimum ANI threshold (nucleotide gather) or AAI threshold (protein gather) for classification; default={ani_default}',
+    )
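Downstream, the taxonomy subcommands that call add_tax_threshold_arg pick up an optional ANI/AAI classification cutoff alongside the containment threshold. A hypothetical invocation (the filenames, the 0.94 cutoff, and the short flags are placeholders based on my reading of sourmash tax usage, not taken from this PR):

    sourmash tax genome -g sample.gather.csv -t taxonomy.csv --ani-threshold 0.94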
53 changes: 20 additions & 33 deletions src/sourmash/commands.py
@@ -16,7 +16,7 @@
 from .logging import notify, error, print_results, set_quiet
 from .sourmash_args import (FileOutput, FileOutputCSV,
                             SaveSignaturesToLocation)
-from .search import prefetch_database, PrefetchResult, calculate_prefetch_info
+from .search import SearchResult, prefetch_database, PrefetchResult, GatherResult, calculate_prefetch_info
 from .index import LazyLinearIndex

 WATERMARK_SIZE = 10000

@@ -108,11 +108,20 @@ def compare(args):
         error('must use scaled signatures with --containment and --max-containment')
         sys.exit(-1)

+    # complain if --ani and not is_scaled
+    return_ani = False
+    if args.estimate_ani:
+        return_ani = True
+
+    if return_ani and not is_scaled:
+        error('must use scaled signatures with --estimate-ani')
+        sys.exit(-1)
+
     # notify about implicit --ignore-abundance:
-    if is_containment:
+    if is_containment or return_ani:
         track_abundances = any(( s.minhash.track_abundance for s in siglist ))
         if track_abundances:
-            notify('NOTE: --containment and --max-containment ignore signature abundances.')
+            notify('NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.')

     # if using --scaled, downsample appropriately
     printed_scaled_msg = False

@@ -138,12 +147,12 @@ def compare(args):

     labeltext = [str(item) for item in siglist]
     if args.containment:
-        similarity = compare_serial_containment(siglist)
+        similarity = compare_serial_containment(siglist, return_ani=return_ani)
     elif args.max_containment:
-        similarity = compare_serial_max_containment(siglist)
+        similarity = compare_serial_max_containment(siglist, return_ani=return_ani)
     else:
         similarity = compare_all_pairs(siglist, args.ignore_abundance,
-                                       n_jobs=args.processes)
+                                       n_jobs=args.processes, return_ani=return_ani)

     if len(siglist) < 30:
         for i, E in enumerate(siglist):

@@ -525,8 +534,7 @@ def search(args):
         notify("** reporting only one match because --best-only was set")

     if args.output:
-        fieldnames = ['similarity', 'name', 'filename', 'md5',
-                      'query_filename', 'query_name', 'query_md5']
+        fieldnames = SearchResult._fields

         with FileOutputCSV(args.output) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)

@@ -679,11 +687,7 @@ def gather(args):
     prefetch_csvout_fp = None
     prefetch_csvout_w = None
     if args.save_prefetch_csv:
-        fieldnames = ['intersect_bp', 'jaccard',
-                      'max_containment', 'f_query_match', 'f_match_query',
-                      'match_filename', 'match_name', 'match_md5', 'match_bp',
-                      'query_filename', 'query_name', 'query_md5', 'query_bp']
-
+        fieldnames = PrefetchResult._fields
         prefetch_csvout_fp = FileOutput(args.save_prefetch_csv, 'wt').open()
         prefetch_csvout_w = csv.DictWriter(prefetch_csvout_fp, fieldnames=fieldnames)
         prefetch_csvout_w.writeheader()

@@ -798,13 +802,7 @@

     # save CSV?
     if found and args.output:
-        fieldnames = ['intersect_bp', 'f_orig_query', 'f_match',
-                      'f_unique_to_query', 'f_unique_weighted',
-                      'average_abund', 'median_abund', 'std_abund', 'name',
-                      'filename', 'md5', 'f_match_orig', 'unique_intersect_bp',
-                      'gather_result_rank', 'remaining_bp',
-                      'query_filename', 'query_name', 'query_md5', 'query_bp']
-
+        fieldnames = GatherResult._fields
         with FileOutputCSV(args.output) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)
             w.writeheader()

@@ -971,14 +969,7 @@ def multigather(args):

         output_base = os.path.basename(query_filename)
         output_csv = output_base + '.csv'
-
-        fieldnames = ['intersect_bp', 'f_orig_query', 'f_match',
-                      'f_unique_to_query', 'f_unique_weighted',
-                      'average_abund', 'median_abund', 'std_abund', 'name',
-                      'filename', 'md5', 'f_match_orig',
-                      'unique_intersect_bp', 'gather_result_rank',
-                      'remaining_bp', 'query_filename', 'query_name',
-                      'query_md5', 'query_bp']
+        fieldnames = GatherResult._fields
         with FileOutputCSV(output_csv) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)
             w.writeheader()

@@ -1177,11 +1168,7 @@ def prefetch(args):
     csvout_fp = None
     csvout_w = None
     if args.output:
-        fieldnames = ['intersect_bp', 'jaccard',
-                      'max_containment', 'f_query_match', 'f_match_query',
-                      'match_filename', 'match_name', 'match_md5', 'match_bp',
-                      'query_filename', 'query_name', 'query_md5', 'query_bp']
-
+        fieldnames = PrefetchResult._fields
         csvout_fp = FileOutput(args.output, 'wt').open()
         csvout_w = csv.DictWriter(csvout_fp, fieldnames=fieldnames)
         csvout_w.writeheader()
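A recurring change above replaces hand-maintained fieldname lists with the result namedtuples' _fields attribute, so the CSV header always matches the result type. A minimal self-contained sketch of the idiom, using a made-up DemoResult rather than sourmash's actual result classes:

    import csv
    from collections import namedtuple

    # hypothetical stand-in for SearchResult / PrefetchResult / GatherResult
    DemoResult = namedtuple('DemoResult', ['query_name', 'match_name', 'ani'])

    results = [DemoResult('q1', 'm1', 0.97), DemoResult('q1', 'm2', 0.93)]

    with open('results.csv', 'w', newline='') as fp:
        # _fields supplies the column names; adding a field to the namedtuple
        # automatically adds a CSV column, with no separate header list to update
        w = csv.DictWriter(fp, fieldnames=DemoResult._fields)
        w.writeheader()
        for r in results:
            w.writerow(r._asdict())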
46 changes: 31 additions & 15 deletions src/sourmash/compare.py
@@ -9,7 +9,7 @@
 from sourmash.np_utils import to_memmap


-def compare_serial(siglist, ignore_abundance, downsample=False):
+def compare_serial(siglist, ignore_abundance, downsample=False, return_ani=False):
     """Compare all combinations of signatures and return a matrix
     of similarities. Processes combinations serially on a single
     process. Best to use when there is few signatures.

@@ -34,12 +34,15 @@ def compare_serial(siglist, ignore_abundance, downsample=False):
     similarities = np.ones((n, n))

     for i, j in iterator:
-        similarities[i][j] = similarities[j][i] = siglist[i].similarity(siglist[j], ignore_abundance, downsample)
+        if return_ani:
+            similarities[i][j] = similarities[j][i] = siglist[i].jaccard_ani(siglist[j], downsample)[0]
+        else:
+            similarities[i][j] = similarities[j][i] = siglist[i].similarity(siglist[j], ignore_abundance, downsample)

     return similarities


-def compare_serial_containment(siglist, downsample=False):
+def compare_serial_containment(siglist, downsample=False, return_ani=False):
     """Compare all combinations of signatures and return a matrix
     of containments. Processes combinations serially on a single
     process. Best to only use when there are few signatures.

@@ -55,13 +58,17 @@ def compare_serial_containment(siglist, downsample=False):
     containments = np.ones((n, n))
     for i in range(n):
         for j in range(n):
-            containments[i][j] = siglist[j].contained_by(siglist[i],
+            if return_ani:
+                containments[i][j] = siglist[j].containment_ani(siglist[i],
+                                                                downsample=downsample)[0]
+            else:
+                containments[i][j] = siglist[j].contained_by(siglist[i],
                                                          downsample=downsample)

     return containments


-def compare_serial_max_containment(siglist, downsample=False):
+def compare_serial_max_containment(siglist, downsample=False, return_ani=False):
     """Compare all combinations of signatures and return a matrix
     of max_containments. Processes combinations serially on a single
     process. Best to only use when there are few signatures.

@@ -77,22 +84,30 @@ def compare_serial_max_containment(siglist, downsample=False):
     containments = np.ones((n, n))
     for i in range(n):
         for j in range(n):
-            containments[i][j] = siglist[j].max_containment(siglist[i],
+            if return_ani:
+                containments[i][j] = siglist[j].max_containment_ani(siglist[i],
+                                                                    downsample=downsample)[0]
+            else:
+                containments[i][j] = siglist[j].max_containment(siglist[i],
                                                          downsample=downsample)

     return containments


-def similarity_args_unpack(args, ignore_abundance, downsample):
+def similarity_args_unpack(args, ignore_abundance, downsample, return_ani=False):
     """Helper function to unpack the arguments. Written to use in pool.imap
     as it can only be given one argument."""
     sig1, sig2 = args
-    return sig1.similarity(sig2,
+    if return_ani:
+        return sig1.jaccard_ani(sig2,
+                                downsample=downsample)[0]
+    else:
+        return sig1.similarity(sig2,
                            ignore_abundance=ignore_abundance,
                            downsample=downsample)


-def get_similarities_at_index(index, ignore_abundance, downsample, siglist):
+def get_similarities_at_index(index, ignore_abundance, downsample, siglist, return_ani=False):
     """Returns similarities of all the combinations of signature at index in
     the siglist with the rest of the indices starting at index + 1. Doesn't
     redundantly calculate signatures with all the other indices prior to

@@ -114,14 +129,14 @@ def get_similarities_at_index(index, ignore_abundance, downsample, siglist):
     sig_iterator = itertools.product([siglist[index]], siglist[index + 1:])
     func = partial(similarity_args_unpack,
                    ignore_abundance=ignore_abundance,
-                   downsample=downsample)
+                   downsample=downsample, return_ani=return_ani)
     similarity_list = list(map(func, sig_iterator))
     notify(
         f"comparison for index {index} done in {time.time() - startt:.5f} seconds", end='\r')
     return similarity_list


-def compare_parallel(siglist, ignore_abundance, downsample, n_jobs):
+def compare_parallel(siglist, ignore_abundance, downsample, n_jobs, return_ani=False):
     """Compare all combinations of signatures and return a matrix
     of similarities. Processes combinations parallely on number of processes
     given by n_jobs

@@ -163,7 +178,8 @@ def compare_parallel(siglist, ignore_abundance, downsample, n_jobs):
         get_similarities_at_index,
         siglist=siglist,
         ignore_abundance=ignore_abundance,
-        downsample=downsample)
+        downsample=downsample,
+        return_ani=return_ani)
     notify("Created similarity func")

     # Initialize multiprocess.pool

@@ -198,7 +214,7 @@ def compare_parallel(siglist, ignore_abundance, downsample, n_jobs):
     return np.memmap(filename, dtype=np.float64, shape=(length_siglist, length_siglist))


-def compare_all_pairs(siglist, ignore_abundance, downsample=False, n_jobs=None):
+def compare_all_pairs(siglist, ignore_abundance, downsample=False, n_jobs=None, return_ani=False):
     """Compare all combinations of signatures and return a matrix
     of similarities. Processes combinations either serially or
     based on parallely on number of processes given by n_jobs

@@ -216,7 +232,7 @@ def compare_all_pairs(siglist, ignore_abundance, downsample=False, n_jobs=None):
     :return: np.array similarity matrix
     """
     if n_jobs is None or n_jobs == 1:
-        similarities = compare_serial(siglist, ignore_abundance, downsample)
+        similarities = compare_serial(siglist, ignore_abundance, downsample, return_ani)
     else:
-        similarities = compare_parallel(siglist, ignore_abundance, downsample, n_jobs)
+        similarities = compare_parallel(siglist, ignore_abundance, downsample, n_jobs, return_ani)
     return similarities
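For intuition about what return_ani reports: the preprint cited in the CLI help (https://doi.org/10.1101/2022.01.11.475870) derives ANI point estimates, with confidence intervals, from FracMinHash jaccard and containment; the [0] indexing in the hunks above takes just the point estimate, with the rest of the returned tuple carrying the confidence interval (per the commit history). A rough sketch of the point estimates alone, using the familiar k-mer identities; this is my own illustration, not sourmash's implementation, which also handles confidence bounds and downsampling:

    import math

    def containment_to_ani(containment, ksize):
        # point estimate: ANI ~ C**(1/k), where C is the containment of one
        # sketch in the other and k is the k-mer size
        return containment ** (1.0 / ksize)

    def jaccard_to_ani(jaccard, ksize):
        # point estimate via the Mash-style distance
        # D = -(1/k) * ln(2J / (1 + J)), reported here as ANI = 1 - D
        if jaccard == 0:
            return 0.0
        dist = -math.log(2 * jaccard / (1 + jaccard)) / ksize
        return 1.0 - dist

    print(containment_to_ani(0.5, 31))   # ~0.978
    print(jaccard_to_ani(0.5, 31))       # ~0.987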