Skip to content

Commit

Permalink
[MRG] add an lca gather signature decomposition approach (#390)
Browse files Browse the repository at this point in the history
* Bump version and add note pointing to stable docs
* add an initial lca gather greedy search implementation
* add output options
* don't collapse lineages
* add some docs on search, gather, and lca
* fix argparse reporting
* update tests for strain addition
* fix genus display if _ is present
* fix very challenging error to find :)
* output equivalent matches
* add test for equivalent results display and csv output
* make unassigned slightly nicer
* link to LCA databases
* catch parse error -> question if it's a valid LCA DB
* add in consistent display by name of match for both sourmash gather and lca gatherc
  • Loading branch information
ctb authored and luizirber committed Feb 23, 2018
1 parent 27dad7c commit 102b9ea
Show file tree
Hide file tree
Showing 19 changed files with 457 additions and 63 deletions.
2 changes: 1 addition & 1 deletion doc/command-line.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ using a collection of genomes with taxonomic information.

These commands use LCA databases (created with `lca index`, below, or
prepared databases such as
[genbank-k31.lca.json.gz, from the LCA tutorial](tutorials-lca.html).
[genbank-k31.lca.json.gz](databases.html)).

### `sourmash lca classify`

Expand Down
4 changes: 2 additions & 2 deletions sourmash_lib/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,15 +999,15 @@ def gather(args):
if args.output_unassigned:
if not found:
notify('nothing found - entire query signature unassigned.')
if not query.minhash.get_mins():
elif not query.minhash.get_mins():
notify('no unassigned hashes! not saving.')
else:
outname = args.output_unassigned.name
notify('saving unassigned hashes to "{}"', outname)

e = sourmash_lib.MinHash(ksize=query.minhash.ksize, n=0,
max_hash=new_max_hash)
e.add_many(query.minhash.get_mins())
e.add_many(next_query.minhash.get_mins())
sig.save_signatures([ sig.SourmashSignature(e) ],
args.output_unassigned)

Expand Down
1 change: 1 addition & 0 deletions sourmash_lib/lca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .command_classify import classify
from .command_summarize import summarize_main
from .command_rankinfo import rankinfo_main
from .command_gather import gather_main
from .__main__ import main
3 changes: 2 additions & 1 deletion sourmash_lib/lca/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import argparse

from . import classify, index, summarize_main, rankinfo_main
from . import classify, index, summarize_main, rankinfo_main, gather_main
from .command_compare_csv import compare_csv
from sourmash_lib.logging import set_quiet, error

Expand Down Expand Up @@ -33,6 +33,7 @@ def main(sysv_args):
'index': index,
'summarize': summarize_main,
'rankinfo': rankinfo_main,
'gather': gather_main,
'compare_csv': compare_csv}

parser = argparse.ArgumentParser(
Expand Down
3 changes: 1 addition & 2 deletions sourmash_lib/lca/command_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def classify(args):
"""
main single-genome classification function.
"""
p = argparse.ArgumentParser()
p = argparse.ArgumentParser(prog="sourmash lca classify")
p.add_argument('--db', nargs='+', action='append')
p.add_argument('--query', nargs='+', action='append')
p.add_argument('--threshold', type=int, default=DEFAULT_THRESHOLD)
Expand Down Expand Up @@ -109,7 +109,6 @@ def classify(args):

# load all the databases
dblist, ksize, scaled = lca_utils.load_databases(args.db, args.scaled)
notify('ksize={} scaled={}', ksize, scaled)

# find all the queries
notify('finding query signatures...')
Expand Down
2 changes: 1 addition & 1 deletion sourmash_lib/lca/command_compare_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sourmash_lib.lca.command_index import load_taxonomy_assignments

def compare_csv(args):
p = argparse.ArgumentParser()
p = argparse.ArgumentParser(prog="sourmash lca compare_csv")
p.add_argument('csv1', help='taxonomy spreadsheet output by classify')
p.add_argument('csv2', help='custom taxonomy spreadsheet')
p.add_argument('-d', '--debug', action='store_true')
Expand Down
292 changes: 292 additions & 0 deletions sourmash_lib/lca/command_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
#! /usr/bin/env python
"""
Execute a greedy search on lineages attached to hashvals in the query.
Mimics `sourmash gather` but provides taxonomic information.
"""
from __future__ import print_function, division
import sys
import argparse
import csv
from collections import Counter, defaultdict, namedtuple

import sourmash_lib
from sourmash_lib import sourmash_args
from sourmash_lib.logging import notify, error, print_results
from sourmash_lib.lca import lca_utils
from sourmash_lib.lca.lca_utils import debug, set_debug
from sourmash_lib.search import format_bp

LCAGatherResult = namedtuple('LCAGatherResult',
'intersect_bp, f_unique_to_query, f_unique_weighted, average_abund, lineage, f_match, name, n_equal_matches')


def format_lineage(lineage_tup):
"""
Pretty print lineage.
"""
# list of ranks present
present = [ l.rank for l in lineage_tup if l.name ]
d = dict(lineage_tup) # rank: value

if 'genus' in present:
genus = d['genus']
if 'strain' in present:
name = d['strain']
elif 'species' in present:
species = d['species']
if species.startswith(genus + ' ') or \
species.startswith(genus + '_'):
name = species
else:
name = '{} {}'.format(genus, species)
else:
name = '{} sp.'.format(genus)
elif len(present) < 3:
lineage_str = lca_utils.zip_lineage(lineage_tup, truncate_empty=True)
lineage_str = "; ".join(lineage_str)
name = lineage_str + ' - (no further assignment)'
elif len(present) > 1 and 'superkingdom' in present:
lowest_rank = present[-1]
name = '{}; {} {}'.format(d['superkingdom'], lowest_rank,
d[lowest_rank])
else:
lineage_str = lca_utils.zip_lineage(lineage_tup, truncate_empty=True)
lineage_str = "; ".join(lineage_str)
name = lineage_str

return name


def gather_signature(query_sig, dblist, ignore_abundance):
"""
Decompose 'query_sig' using the given list of databases.
"""
notify('loaded query: {}... (k={})', query_sig.name()[:30],
query_sig.minhash.ksize)

# extract the basic set of mins
query_mins = set(query_sig.minhash.get_mins())
n_mins = len(query_mins)

if query_sig.minhash.track_abundance and not ignore_abundance:
orig_abunds = query_sig.minhash.get_mins(with_abundance=True)
else:
if query_sig.minhash.track_abundance and ignore_abundance:
notify('** ignoring abundance')
orig_abunds = { k: 1 for k in query_mins }
sum_abunds = sum(orig_abunds.values())


# collect all mentioned lineage_ids -> md5s, from across the databases
md5_to_lineage = {}
md5_to_name = {}

x = set()
for hashval in query_mins:
for lca_db in dblist:
lineage_ids = lca_db.hashval_to_lineage_id.get(hashval, [])
for lid in lineage_ids:
md5 = lca_db.lineage_id_to_signature[lid]
x.add((lca_db, lid, md5))

for lca_db, lid, md5 in x:
md5_to_lineage[md5] = lca_db.lineage_dict[lid]
if lca_db.signature_to_name:
md5_to_name[md5] = lca_db.signature_to_name[md5]
else:
md5_to_name[md5] = ''

# now! do the gather:
while 1:
# find all of the assignments for the current set of hashes
assignments = defaultdict(set)
for hashval in query_mins:
for lca_db in dblist:
lineage_ids = lca_db.hashval_to_lineage_id.get(hashval, [])
for lid in lineage_ids:
md5 = lca_db.lineage_id_to_signature[lid]
signature_size = lca_db.lineage_id_counts[lid]
assignments[hashval].add((md5, signature_size))

# none? quit.
if not assignments:
break

# count the distinct signatures.
counts = Counter()
for hashval, md5_set in assignments.items():
for (md5, sigsize) in md5_set:
counts[(md5, sigsize)] += 1

# collect the most abundant assignments
common_iter = iter(counts.most_common())
best_list = []
(md5, sigsize), top_count = next(common_iter)

best_list.append((md5_to_name[md5], md5, sigsize))
for (md5, sigsize), count in common_iter:
if count != top_count:
break
best_list.append((md5_to_name[md5], md5, sigsize))

# sort on name and pick the top (for consistency).
best_list.sort()
_, top_md5, top_sigsize = best_list[0]

equiv_counts = len(best_list) - 1

# now, remove from query mins.
intersect_mins = set()
for hashval, md5_set in assignments.items():
if (top_md5, top_sigsize) in md5_set:
query_mins.remove(hashval)
intersect_mins.add(hashval)

# should match!
assert top_count == len(intersect_mins)

# calculate size of match (# of hashvals belonging to that sig)
match_size = top_sigsize

# construct 'result' object
intersect_bp = top_count * query_sig.minhash.scaled
f_unique_weighted = sum((orig_abunds[k] for k in intersect_mins)) \
/ sum_abunds
average_abund = sum((orig_abunds[k] for k in intersect_mins)) \
/ len(intersect_mins)
f_match = len(intersect_mins) / match_size

result = LCAGatherResult(intersect_bp = intersect_bp,
f_unique_to_query= top_count / n_mins,
f_unique_weighted=f_unique_weighted,
average_abund=average_abund,
f_match=f_match,
lineage=md5_to_lineage[top_md5],
name=md5_to_name[top_md5],
n_equal_matches=equiv_counts)

f_unassigned = len(query_mins) / n_mins
est_bp = len(query_mins) * query_sig.minhash.scaled

yield result, f_unassigned, est_bp, query_mins

## done.



def gather_main(args):
"""
Do a greedy search for the hash components of a query against an LCA db.
Here we don't actually do a least-common-ancestor search of any kind; we
do essentially the same kind of search as we do in `sourmash gather`, with
the main difference that we are implicitly combining different genomes of
identical lineages.
This takes advantage of the structure of the LCA db, where we store the
full lineage information for each known hash, as opposed to storing only
the least-common-ancestor information for it.
"""
p = argparse.ArgumentParser(prog="sourmash lca gather")
p.add_argument('query')
p.add_argument('db', nargs='+')
p.add_argument('-d', '--debug', action='store_true')
p.add_argument('-o', '--output', type=argparse.FileType('wt'),
help='output CSV containing matches to this file')
p.add_argument('--output-unassigned', type=argparse.FileType('wt'),
help='output unassigned portions of the query as a signature to this file')
p.add_argument('--ignore-abundance', action='store_true',
help='do NOT use k-mer abundances if present')
args = p.parse_args(args)

if args.debug:
set_debug(args.debug)

# load all the databases
dblist, ksize, scaled = lca_utils.load_databases(args.db, None)

# for each query, gather all the matches across databases
query_sig = sourmash_args.load_query_signature(args.query, ksize, 'DNA')
debug('classifying', query_sig.name())

# make sure we're looking at the same scaled value as database
query_sig.minhash = query_sig.minhash.downsample_scaled(scaled)

# do the classification, output results
found = []
for result, f_unassigned, est_bp, remaining_mins in gather_signature(query_sig, dblist, args.ignore_abundance):
# is this our first time through the loop? print headers, if so.
if not len(found):
print_results("")
print_results("overlap p_query p_match ")
print_results("--------- ------- --------")

# output!
pct_query = '{:.1f}%'.format(result.f_unique_to_query*100)
pct_match = '{:.1f}%'.format(result.f_match*100)
str_bp = format_bp(result.intersect_bp)
name = format_lineage(result.lineage)

equal_match_str = ""
if result.n_equal_matches:
equal_match_str = " (** {} equal matches)".format(result.n_equal_matches)

print_results('{:9} {:>6} {:>6} {}{}', str_bp, pct_query,
pct_match, name, equal_match_str)

found.append(result)

if found:
print_results('')
if f_unassigned:
print_results('{:.1f}% ({}) of hashes have no assignment.', f_unassigned*100,
format_bp(est_bp))
else:
print_results('Query is completely assigned.')
print_results('')
# nothing found.
else:
est_bp = len(query_sig.minhash.get_mins()) * query_sig.minhash.scaled
print_results('')
print_results('No assignment for est {} of sequence.',
format_bp(est_bp))
print_results('')

if not found:
sys.exit(0)

if args.output:
fieldnames = ['intersect_bp', 'f_match', 'f_unique_to_query', 'f_unique_weighted',
'average_abund', 'name', 'n_equal_matches'] + list(lca_utils.taxlist())

w = csv.DictWriter(args.output, fieldnames=fieldnames)
w.writeheader()
for result in found:
lineage = result.lineage
d = dict(result._asdict())
del d['lineage']

for (rank, value) in lineage:
d[rank] = value

w.writerow(d)

if args.output_unassigned:
if not found:
notify('nothing found - entire query signature unassigned.')
elif not remaining_mins:
notify('no unassigned hashes! not saving.')
else:
outname = args.output_unassigned.name
notify('saving unassigned hashes to "{}"', outname)

e = query_sig.minhash.copy_and_clear()
e.add_many(remaining_mins)

sourmash_lib.save_signatures([ sourmash_lib.SourmashSignature(e) ],
args.output_unassigned)


if __name__ == '__main__':
sys.exit(gather_main(sys.argv[1:]))
Loading

0 comments on commit 102b9ea

Please sign in to comment.