diff --git a/src/sourmash/lca/command_summarize.py b/src/sourmash/lca/command_summarize.py index 8ff3e99070..d8acaea867 100644 --- a/src/sourmash/lca/command_summarize.py +++ b/src/sourmash/lca/command_summarize.py @@ -116,7 +116,8 @@ def output_results(lineage_counts, total_counts, filename=None, sig=None): print_results('{:5} {:>5} {} {}:{} {}'.format(p, count, lineage, filename, sig.md5sum()[:8], sig)) -def output_csv(lineage_counts, csv_fp, filename, sig, write_header=True): +def output_csv(lineage_counts, total_counts, csv_fp, filename, sig, + write_header=True): """\ Output results in CSV. """ @@ -124,13 +125,13 @@ def output_csv(lineage_counts, csv_fp, filename, sig, write_header=True): w = csv.writer(csv_fp) if write_header: headers = ['count'] + list(lca_utils.taxlist()) - headers += ['filename', 'sig_name', 'sig_md5'] + headers += ['filename', 'sig_name', 'sig_md5', 'total_counts'] w.writerow(headers) for (lineage, count) in lineage_counts.items(): debug('lineage:', lineage) row = [count] + lca_utils.zip_lineage(lineage, truncate_empty=False) - row += [filename, sig.name, sig.md5sum()] + row += [filename, sig.name, sig.md5sum(), total_counts] w.writerow(row) @@ -198,7 +199,7 @@ def summarize_main(args): filename=filename, sig=sig) if csv_fp: - output_csv(lineage_counts, csv_fp, filename, sig, + output_csv(lineage_counts, total, csv_fp, filename, sig, write_header=write_header) write_header = False finally: diff --git a/tests/test_lca.py b/tests/test_lca.py index d83eff7234..2467914b13 100644 --- a/tests/test_lca.py +++ b/tests/test_lca.py @@ -1693,10 +1693,47 @@ def test_single_summarize_to_output_check_filename(runtmp): outdata = open(runtmp.output('output.txt'), 'rt').read() assert 'loaded 1 signatures from 1 files total.' in runtmp.last_result.err - assert 'count,superkingdom,phylum,class,order,family,genus,species,strain,filename,sig_name,sig_md5\n' in outdata - assert '200,Bacteria,Proteobacteria,Gammaproteobacteria,Alteromonadales,Alteromonadaceae,Alteromonas,Alteromonas_macleodii,,'+os.path.join(in_dir, 'q.sig')+',TARA_ASE_MAG_00031,5b438c6c858cdaf9e9b05a207fa3f9f0' in outdata + assert 'count,superkingdom,phylum,class,order,family,genus,species,strain,filename,sig_name,sig_md5,total_counts\n' in outdata + assert '200,Bacteria,Proteobacteria,Gammaproteobacteria,Alteromonadales,Alteromonadaceae,Alteromonas,Alteromonas_macleodii,,'+os.path.join(in_dir, 'q.sig')+',TARA_ASE_MAG_00031,5b438c6c858cdaf9e9b05a207fa3f9f0,200.0\n' in outdata + print(outdata) +def test_summarize_unknown_hashes_to_output_check_total_counts(runtmp): + taxcsv = utils.get_test_data('lca-root/tax.csv') + input_sig1 = utils.get_test_data('lca-root/TARA_MED_MAG_00029.fa.sig') + input_sig2 = utils.get_test_data('lca-root/TOBG_MED-875.fna.gz.sig') + lca_db = runtmp.output('lca-root.lca.json') + + cmd = ['lca', 'index', taxcsv, lca_db, input_sig2] + runtmp.sourmash(*cmd) + + print(cmd) + print(runtmp.last_result.out) + print(runtmp.last_result.err) + + assert os.path.exists(lca_db) + + assert '1 identifiers used out of 2 distinct identifiers in spreadsheet.' in runtmp.last_result.err + + cmd = ['lca', 'summarize', '--db', lca_db, '--query', input_sig1, + '-o', 'out.csv'] + runtmp.sourmash(*cmd) + + print(cmd) + print(runtmp.last_result.out) + print(runtmp.last_result.err) + + assert '(root)' not in runtmp.last_result.out + assert '11.5% 27 Archaea;Euryarcheoata;unassigned;unassigned;novelFamily_I' in runtmp.last_result.out + + with open(runtmp.output('out.csv'), newline="") as fp: + r = csv.DictReader(fp) + rows = list(r) + pairs = [ (row['count'], row['total_counts']) for row in rows ] + pairs = [ (float(x), float(y)) for x, y in pairs ] + pairs = set(pairs) + + assert pairs == { (27.0, 234.0) } def test_single_summarize_scaled(runtmp):