Skip to content

Commit

Permalink
[MRG] use and report ANI from tax genome summarization (#2005)
Browse files Browse the repository at this point in the history
* init with tax code from #1788

* tax threshold arg from #1788

* add back sql changes from latest

* make sure we can still use tax without ANI col

* dont warn about ignoring ANI more than once

* check query stats

* fix gather test

* add sentence on ani thresh

* apply sugg from code review
  • Loading branch information
bluegenes committed Jul 24, 2022
1 parent 1bc273d commit f3a4b88
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 91 deletions.
23 changes: 14 additions & 9 deletions doc/command-line.md
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,9 @@ The sourmash `tax` or `taxonomy` commands integrate taxonomic
taxonomic rank. For example, if the gather results for a metagenome
include results for 30 different strains of a given species, we can sum
the fraction uniquely matched to each strain to obtain the fraction
uniquely matched to this species. Note that this summarization can
also take into account abundance weighting; see
[classifying signatures](classifying-signatures.md) for more
information.
uniquely matched to this species. Alternatively, taxonomic summarization
can take into account abundance weighting; see
[classifying signatures](classifying-signatures.md) for more information.

As with all reference-based analysis, results can be affected by the
completeness of the reference database. However, summarizing taxonomic
Expand Down Expand Up @@ -589,11 +588,17 @@ To produce multiple output types from the same command, add the types into the
### `sourmash tax genome` - classify a genome using `gather` results

`sourmash tax genome` reports likely classification for each query,
based on `gather` matches. By default, classification requires at least 10% of
the query to be matched. Thus, if 10% of the query was matched to a species, the
species-level classification can be reported. However, if 7% of the query was
matched to one species, and an additional 5% matched to a different species in
the same genus, the genus-level classification will be reported.
based on `gather` matches. By default, classification requires at least 10%
of the query to be matched. Thus, if 10% of the query was matched to a species,
the species-level classification can be reported. However, if 7% of the query
was matched to one species, and an additional 5% matched to a different species
in the same genus, the genus-level classification will be reported.

`sourmash tax genome` can use an ANI threshold (`--ani-threshold`) instead of a
containment threshold. This works the same way as the containment threshold
(and indeed, is using the same underlying information). Note that for DNA k-mers,
k=21 ANI is most similar to alignment-based ANI values, and ANI values should only
be compared if they were generated using the same ksize.

Optionally, `genome` can instead report classifications at a desired `rank`,
regardless of match threshold (`--rank` argument, e.g. `--rank species`).
Expand Down
10 changes: 7 additions & 3 deletions src/sourmash/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,14 @@ def range_limited_float_type(arg):
return f


def add_tax_threshold_arg(parser, default=0.1):
def add_tax_threshold_arg(parser, containment_default=0.1, ani_default=None):
parser.add_argument(
'--containment-threshold', default=default, type=range_limited_float_type,
help=f'minimum containment threshold for classification; default={default}'
'--containment-threshold', default=containment_default, type=range_limited_float_type,
help=f'minimum containment threshold for classification; default={containment_default}',
)
parser.add_argument(
'--ani-threshold', '--aai-threshold', default=ani_default, type=range_limited_float_type,
help=f'minimum ANI threshold (nucleotide gather) or AAI threshold (protein gather) for classification; default={ani_default}',
)


Expand Down
41 changes: 27 additions & 14 deletions src/sourmash/tax/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def metagenome(args):
seen_perfect = set()
for rank in sourmash.lca.taxlist(include_strain=False):
try:
summarized_gather[rank], seen_perfect = tax_utils.summarize_gather_at(rank, tax_assign, gather_results, skip_idents=idents_missed,
summarized_gather[rank], seen_perfect, _ = tax_utils.summarize_gather_at(rank, tax_assign, gather_results, skip_idents=idents_missed,
keep_full_identifiers=args.keep_full_identifiers,
keep_identifier_versions = args.keep_identifier_versions,
seen_perfect = seen_perfect)
Expand Down Expand Up @@ -179,12 +179,14 @@ def genome(args):
sys.exit(-1)

# if --rank is specified, classify to that rank
estimate_query_ani = True
if args.rank:
try:
best_at_rank, seen_perfect = tax_utils.summarize_gather_at(args.rank, tax_assign, gather_results, skip_idents=idents_missed,
best_at_rank, seen_perfect, estimate_query_ani = tax_utils.summarize_gather_at(args.rank, tax_assign, gather_results, skip_idents=idents_missed,
keep_full_identifiers=args.keep_full_identifiers,
keep_identifier_versions = args.keep_identifier_versions,
best_only=True, seen_perfect=seen_perfect)
best_only=True, seen_perfect=seen_perfect, estimate_query_ani=True)

except ValueError as exc:
error(f"ERROR: {str(exc)}")
sys.exit(-1)
Expand All @@ -194,27 +196,30 @@ def genome(args):
status = 'nomatch'
if sg.query_name in matched_queries:
continue
if sg.fraction <= args.containment_threshold:
if args.ani_threshold and sg.query_ani_at_rank < args.ani_threshold:
status="below_threshold"
notify(f"WARNING: classifying query {sg.query_name} at desired rank {args.rank} does not meet query ANI/AAI threshold {args.ani_threshold}")
elif sg.fraction <= args.containment_threshold: # should this just be less than?
status="below_threshold"
notify(f"WARNING: classifying query {sg.query_name} at desired rank {args.rank} does not meet containment threshold {args.containment_threshold}")
else:
status="match"
classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank)
classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank, sg.query_ani_at_rank)
classifications[args.rank].append(classif)
matched_queries.add(sg.query_name)
if "krona" in args.output_format:
lin_list = display_lineage(sg.lineage).split(';')
krona_results.append((sg.fraction, *lin_list))
else:
# classify to the match that passes the containment threshold.
# classify to the rank/match that passes the containment threshold.
# To do - do we want to store anything for this match if nothing >= containment threshold?
for rank in tax_utils.ascending_taxlist(include_strain=False):
# gets best_at_rank for all queries in this gather_csv
try:
best_at_rank, seen_perfect = tax_utils.summarize_gather_at(rank, tax_assign, gather_results, skip_idents=idents_missed,
keep_full_identifiers=args.keep_full_identifiers,
keep_identifier_versions = args.keep_identifier_versions,
best_only=True, seen_perfect=seen_perfect)
best_at_rank, seen_perfect, estimate_query_ani = tax_utils.summarize_gather_at(rank, tax_assign, gather_results, skip_idents=idents_missed,
keep_full_identifiers=args.keep_full_identifiers,
keep_identifier_versions = args.keep_identifier_versions,
best_only=True, seen_perfect=seen_perfect, estimate_query_ani=estimate_query_ani)
except ValueError as exc:
error(f"ERROR: {str(exc)}")
sys.exit(-1)
Expand All @@ -223,18 +228,26 @@ def genome(args):
status = 'nomatch'
if sg.query_name in matched_queries:
continue
if sg.fraction >= args.containment_threshold:
if sg.query_ani_at_rank is not None and args.ani_threshold and sg.query_ani_at_rank >= args.ani_threshold:
status="match"
elif sg.fraction >= args.containment_threshold:
status = "match"
classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank)
if status == "match":
classif = ClassificationResult(query_name=sg.query_name, status=status, rank=sg.rank,
fraction=sg.fraction, lineage=sg.lineage,
query_md5=sg.query_md5, query_filename=sg.query_filename,
f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank,
query_ani_at_rank= sg.query_ani_at_rank)
classifications[sg.rank].append(classif)
matched_queries.add(sg.query_name)
continue
if rank == "superkingdom" and status == "nomatch":
elif rank == "superkingdom" and status == "nomatch":
status="below_threshold"
classif = ClassificationResult(query_name=sg.query_name, status=status,
rank="", fraction=0, lineage="",
query_md5=sg.query_md5, query_filename=sg.query_filename,
f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank)
f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank,
query_ani_at_rank=sg.query_ani_at_rank)
classifications[sg.rank].append(classif)

if not any([classifications, krona_results]):
Expand Down
61 changes: 45 additions & 16 deletions src/sourmash/tax/tax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sourmash import sqlite_utils
from sourmash.exceptions import IndexNotSupported
from sourmash.distance_utils import containment_to_distance

import sqlite3

Expand All @@ -24,9 +25,9 @@
from sourmash.logging import notify
from sourmash.sourmash_args import load_pathlist_from_file

QueryInfo = namedtuple("QueryInfo", "query_md5, query_filename, query_bp")
SummarizedGatherResult = namedtuple("SummarizedGatherResult", "query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank")
ClassificationResult = namedtuple("ClassificationResult", "query_name, status, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank")
QueryInfo = namedtuple("QueryInfo", "query_md5, query_filename, query_bp, query_hashes")
SummarizedGatherResult = namedtuple("SummarizedGatherResult", "query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank, query_ani_at_rank")
ClassificationResult = namedtuple("ClassificationResult", "query_name, status, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank, query_ani_at_rank")

# Essential Gather column names that must be in gather_csv to allow `tax` summarization
EssentialGatherColnames = ('query_name', 'name', 'f_unique_weighted', 'f_unique_to_query', 'unique_intersect_bp', 'remaining_bp', 'query_md5', 'query_filename')
Expand Down Expand Up @@ -188,7 +189,8 @@ def find_match_lineage(match_ident, tax_assign, *, skip_idents = [],
def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [],
keep_full_identifiers=False,
keep_identifier_versions=False, best_only=False,
seen_perfect=set()):
seen_perfect=set(),
estimate_query_ani=False):
"""
Summarize gather results at specified taxonomic rank
"""
Expand All @@ -198,6 +200,7 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [],
sum_uniq_to_query = defaultdict(lambda: defaultdict(float))
sum_uniq_bp = defaultdict(lambda: defaultdict(float))
query_info = {}
ksize, scaled, query_nhashes=None, None, None

for row in gather_results:
# get essential gather info
Expand All @@ -208,13 +211,27 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [],
query_md5 = row['query_md5']
query_filename = row['query_filename']
# get query_bp
if query_name not in query_info.keys():
query_bp = unique_intersect_bp + int(row['remaining_bp'])
if query_name not in query_info.keys(): #REMOVING THIS AFFECTS GATHER RESULTS!!! BUT query bp should always be same for same query? bug?
if "query_nhashes" in row.keys():
query_nhashes = int(row["query_nhashes"])
if "query_bp" in row.keys():
query_bp = int(row["query_bp"])
else:
query_bp = unique_intersect_bp + int(row['remaining_bp'])

# store query info
query_info[query_name] = QueryInfo(query_md5=query_md5, query_filename=query_filename, query_bp=query_bp)
query_info[query_name] = QueryInfo(query_md5=query_md5, query_filename=query_filename, query_bp=query_bp, query_hashes = query_nhashes)

if estimate_query_ani and (not ksize or not scaled): # just need to set these once. BUT, if we have these, should we check for compatibility when loading the gather file?
if "ksize" in row.keys():
ksize = int(row['ksize'])
scaled = int(row['scaled'])
else:
estimate_query_ani=False
notify("WARNING: Please run gather with sourmash >= 4.4 to estimate query ANI at rank. Continuing without ANI...")

match_ident = row['name']


# 100% match? are we looking at something in the database?
if f_unique_to_query >= 1.0 and query_name not in seen_perfect: # only want to notify once, not for each rank
ident = get_ident(match_ident,
Expand All @@ -225,16 +242,16 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [],

# get lineage for match
lineage = find_match_lineage(match_ident, tax_assign,
skip_idents=skip_idents,
keep_full_identifiers=keep_full_identifiers,
keep_identifier_versions=keep_identifier_versions)
skip_idents=skip_idents,
keep_full_identifiers=keep_full_identifiers,
keep_identifier_versions=keep_identifier_versions)
# ident was in skip_idents
if not lineage:
continue

# summarize at rank!
lineage = pop_to_rank(lineage, rank)
assert lineage[-1].rank == rank, (rank, lineage[-1])
assert lineage[-1].rank == rank, lineage[-1]
# record info
sum_uniq_to_query[query_name][lineage] += f_unique_to_query
sum_uniq_weighted[query_name][lineage] += f_uniq_weighted
Expand All @@ -246,6 +263,7 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [],
qInfo = query_info[query_name]
sumgather_items = list(lineage_weights.items())
sumgather_items.sort(key = lambda x: -x[1])
query_ani = None
if best_only:
lineage, fraction = sumgather_items[0]
if fraction > 1:
Expand All @@ -254,13 +272,18 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [],
continue
f_weighted_at_rank = sum_uniq_weighted[query_name][lineage]
bp_intersect_at_rank = sum_uniq_bp[query_name][lineage]
sres = SummarizedGatherResult(query_name, rank, fraction, lineage, qInfo.query_md5, qInfo.query_filename, f_weighted_at_rank, bp_intersect_at_rank)
if estimate_query_ani:
query_ani = containment_to_distance(fraction, ksize, scaled,
n_unique_kmers= qInfo.query_hashes, sequence_len_bp= qInfo.query_bp).ani
sres = SummarizedGatherResult(query_name, rank, fraction, lineage, qInfo.query_md5,
qInfo.query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani)
sum_uniq_to_query_sorted.append(sres)
else:
total_f_weighted= 0.0
total_f_classified = 0.0
total_bp_classified = 0
for lineage, fraction in sumgather_items:
query_ani = None
if fraction > 1:
raise ValueError(f"The tax summary of query '{query_name}' is {fraction}, which is > 100% of the query!! This should not be possible. Please check that your input files come directly from a single gather run per query.")
elif fraction == 0:
Expand All @@ -270,19 +293,25 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [],
total_f_weighted += f_weighted_at_rank
bp_intersect_at_rank = int(sum_uniq_bp[query_name][lineage])
total_bp_classified += bp_intersect_at_rank
sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_intersect_at_rank)
if estimate_query_ani:
query_ani = containment_to_distance(fraction, ksize, scaled,
n_unique_kmers=qInfo.query_hashes, sequence_len_bp=qInfo.query_bp).ani
sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5,
query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani)
sum_uniq_to_query_sorted.append(sres)

# record unclassified
lineage = ()
query_ani = None
fraction = 1.0 - total_f_classified
if fraction > 0:
f_weighted_at_rank = 1.0 - total_f_weighted
bp_intersect_at_rank = qInfo.query_bp - total_bp_classified
sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_intersect_at_rank)
sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5,
query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani)
sum_uniq_to_query_sorted.append(sres)

return sum_uniq_to_query_sorted, seen_perfect
return sum_uniq_to_query_sorted, seen_perfect, estimate_query_ani


def find_missing_identities(gather_results, tax_assign):
Expand Down
Loading

0 comments on commit f3a4b88

Please sign in to comment.