diff --git a/doc/command-line.md b/doc/command-line.md index 4697797e1b..14e875a9d6 100644 --- a/doc/command-line.md +++ b/doc/command-line.md @@ -224,7 +224,6 @@ sourmash compare [ ... ] Options: * `--output ` -- save the output matrix to this file, as a numpy binary matrix. -* `--csv ` -- save the output matrix to this file in CSV format. * `--distance-matrix` -- create and output a distance matrix, instead of a similarity matrix. * `--ksize ` -- do the comparisons at this k-mer size. * `--containment` -- calculate containment instead of similarity; `C(i, j) = size(i intersection j) / size(i)` @@ -233,6 +232,7 @@ Options: * `--ignore-abundance` -- ignore abundances in signatures. * `--picklist ::` -- select a subset of signatures with [a picklist](#using-picklists-to-subset-large-collections-of-signatures) * `--csv ` -- save the output matrix in CSV format. +* `--labels-to ` -- create a CSV file (spreadsheet) that can be passed in to `sourmash plot` with `--labels-from` in order to customize the labels. **Note:** compare by default produces a symmetric similarity matrix that can be used for clustering in downstream tasks. With `--containment`, diff --git a/src/sourmash/cli/compare.py b/src/sourmash/cli/compare.py index 74da5bd837..45844aaa1d 100644 --- a/src/sourmash/cli/compare.py +++ b/src/sourmash/cli/compare.py @@ -94,6 +94,11 @@ def subparser(subparsers): metavar="F", help="write matrix to specified file in CSV format (with column " "headers)", ) + subparser.add_argument( + "--labels-to", + "--labels-save", + help="a CSV file containing label information", + ) subparser.add_argument( "-p", "--processes", diff --git a/src/sourmash/cli/plot.py b/src/sourmash/cli/plot.py index 718a5c8528..dd4726c365 100644 --- a/src/sourmash/cli/plot.py +++ b/src/sourmash/cli/plot.py @@ -72,6 +72,11 @@ def subparser(subparsers): help="write clustered matrix and labels out in CSV format (with column" " headers) to this file", ) + subparser.add_argument( + "--labels-from", + "--labels-load", + help="a CSV file containing label information to use on plot; implies --labels", + ) def main(args): diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py index e2d1a09a50..920693f9df 100644 --- a/src/sourmash/commands.py +++ b/src/sourmash/commands.py @@ -71,13 +71,19 @@ def compare(args): notify( f"\nwarning: no signatures loaded at given ksize/molecule type/picklist from {filename}" ) - siglist.extend(loaded) - # track ksizes/moltypes + # add to siglist; track ksizes/moltypes + s = None for s in loaded: + siglist.append((s, filename)) ksizes.add(s.minhash.ksize) moltypes.add(sourmash_args.get_moltype(s)) + if s is None: + notify( + f"\nwarning: no signatures loaded at given ksize/molecule type/picklist from {filename}" + ) + # error out while loading if we have more than one ksize/moltype if len(ksizes) > 1 or len(moltypes) > 1: break @@ -105,7 +111,7 @@ def compare(args): # check to make sure they're potentially compatible - either using # scaled, or not. - scaled_sigs = [s.minhash.scaled for s in siglist] + scaled_sigs = [s.minhash.scaled for (s, _) in siglist] is_scaled = all(scaled_sigs) is_scaled_2 = any(scaled_sigs) @@ -145,16 +151,20 @@ def compare(args): # notify about implicit --ignore-abundance: if is_containment or return_ani: - track_abundances = any(s.minhash.track_abundance for s in siglist) + track_abundances = any(s.minhash.track_abundance for s, _ in siglist) if track_abundances: notify( "NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances." ) + # CTB: note, up to this point, we could do everything with manifests + # w/o actually loading any signatures. I'm not sure the manifest + # API allows it tho. + # if using scaled sketches or --scaled, downsample to common max scaled. printed_scaled_msg = False if is_scaled: - max_scaled = max(s.minhash.scaled for s in siglist) + max_scaled = max(s.minhash.scaled for s, _ in siglist) if args.scaled: args.scaled = int(args.scaled) @@ -166,7 +176,7 @@ def compare(args): notify(f"WARNING: continuing with scaled value of {max_scaled}.") new_siglist = [] - for s in siglist: + for s, filename in siglist: if not size_may_be_inaccurate and not s.minhash.size_is_accurate(): size_may_be_inaccurate = True if s.minhash.scaled != max_scaled: @@ -177,9 +187,9 @@ def compare(args): printed_scaled_msg = True with s.update() as s: s.minhash = s.minhash.downsample(scaled=max_scaled) - new_siglist.append(s) + new_siglist.append((s, filename)) else: - new_siglist.append(s) + new_siglist.append((s, filename)) siglist = new_siglist elif args.scaled is not None: error("ERROR: cannot specify --scaled with non-scaled signatures.") @@ -196,16 +206,20 @@ def compare(args): # do all-by-all calculation - labeltext = [str(item) for item in siglist] + labeltext = [str(ss) for ss, _ in siglist] + sigsonly = [ss for ss, _ in siglist] if args.containment: - similarity = compare_serial_containment(siglist, return_ani=return_ani) + similarity = compare_serial_containment(sigsonly, return_ani=return_ani) elif args.max_containment: - similarity = compare_serial_max_containment(siglist, return_ani=return_ani) + similarity = compare_serial_max_containment(sigsonly, return_ani=return_ani) elif args.avg_containment: - similarity = compare_serial_avg_containment(siglist, return_ani=return_ani) + similarity = compare_serial_avg_containment(sigsonly, return_ani=return_ani) else: similarity = compare_all_pairs( - siglist, args.ignore_abundance, n_jobs=args.processes, return_ani=return_ani + sigsonly, + args.ignore_abundance, + n_jobs=args.processes, + return_ani=return_ani, ) # if distance matrix desired, switch to 1-similarity @@ -215,7 +229,7 @@ def compare(args): matrix = similarity if len(siglist) < 30: - for i, ss in enumerate(siglist): + for i, (ss, filename) in enumerate(siglist): # for small matrices, pretty-print some output name_num = f"{i}-{str(ss)}" if len(name_num) > 20: @@ -246,6 +260,25 @@ def compare(args): with open(args.output, "wb") as fp: numpy.save(fp, matrix) + # output labels information via --labels-to? + if args.labels_to: + labeloutname = args.labels_to + notify(f"saving labels to: {labeloutname}") + with sourmash_args.FileOutputCSV(labeloutname) as fp: + w = csv.writer(fp) + w.writerow( + ["sort_order", "md5", "label", "name", "filename", "signature_file"] + ) + + for n, (ss, location) in enumerate(siglist): + md5 = ss.md5sum() + sigfile = location + label = str(ss) + name = ss.name + filename = ss.filename + + w.writerow([str(n + 1), md5, label, name, filename, sigfile]) + # output CSV? if args.csv: with FileOutputCSV(args.csv) as csv_fp: @@ -289,7 +322,10 @@ def plot(args): notify("...got {} x {} matrix.", *D.shape) # see sourmash#2790 for details :) - if args.labeltext or args.labels: + if args.labeltext or args.labels or args.labels_from: + if args.labeltext and args.labels_from: + notify("ERROR: cannot supply both --labeltext and --labels-from") + sys.exit(-1) display_labels = True args.labels = True # override => labels always true elif args.labels is None and not args.indices: @@ -303,13 +339,24 @@ def plot(args): else: display_labels = False - if args.labels: + if args.labels_from: + labelfilename = args.labels_from + notify(f"loading labels from CSV file '{labelfilename}'") + + labeltext = [] + with sourmash_args.FileInputCSV(labelfilename) as r: + for row in r: + order, label = row["sort_order"], row["label"] + labeltext.append((int(order), label)) + labeltext.sort() + labeltext = [t[1] for t in labeltext] + elif args.labels: if args.labeltext: labelfilename = args.labeltext else: labelfilename = D_filename + ".labels.txt" - notify(f"loading labels from {labelfilename}") + notify(f"loading labels from text file '{labelfilename}'") with open(labelfilename) as f: labeltext = [x.strip() for x in f] diff --git a/tests/test-data/compare/labels_from-test.csv b/tests/test-data/compare/labels_from-test.csv new file mode 100644 index 0000000000..902c045e60 --- /dev/null +++ b/tests/test-data/compare/labels_from-test.csv @@ -0,0 +1,5 @@ +sort_order,md5,label,name,filename,signature_file +4,8a619747693c045afde376263841806b,genome-s10+s11-CHANGED,genome-s10+s11,-,/Users/t/dev/sourmash/tests/test-data/genome-s10+s11.sig +3,ff511252a80bb9a7dbb0acf62626e123,genome-s12-CHANGED,genome-s12,genome-s12.fa.gz,/Users/t/dev/sourmash/tests/test-data/genome-s12.fa.gz.sig +2,1437d8eae64bad9bdc8d13e1daa0a43e,genome-s11-CHANGED,genome-s11,genome-s11.fa.gz,/Users/t/dev/sourmash/tests/test-data/genome-s11.fa.gz.sig +1,4cb3290263eba24548f5bef38bcaefc9,genome-s10-CHANGED,genome-s10,genome-s10.fa.gz,/Users/t/dev/sourmash/tests/test-data/genome-s10.fa.gz.sig \ No newline at end of file diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index 7aaac0446e..ed8cc80b45 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -151,6 +151,7 @@ def test_compare_serial(runtmp): testsigs = utils.get_test_data("genome-s1*.sig") testsigs = glob.glob(testsigs) + assert len(testsigs) == 4 c.run_sourmash("compare", "-o", "cmp", "-k", "21", "--dna", *testsigs) @@ -1252,7 +1253,7 @@ def test_plot_override_labeltext(runtmp): print(runtmp.last_result.out) - assert "loading labels from new.labels.txt" in runtmp.last_result.err + assert "loading labels from text file 'new.labels.txt'" in runtmp.last_result.err expected = """\ 0\ta @@ -1291,7 +1292,7 @@ def test_plot_override_labeltext_fail(runtmp): print(runtmp.last_result.out) print(runtmp.last_result.err) assert runtmp.last_result.status != 0 - assert "loading labels from new.labels.txt" in runtmp.last_result.err + assert "loading labels from text file 'new.labels.txt'" in runtmp.last_result.err assert "3 labels != matrix size, exiting" in runtmp.last_result.err @@ -1406,6 +1407,117 @@ def test_plot_subsample_2(runtmp): assert expected in runtmp.last_result.out +def test_compare_and_plot_labels_from_to(runtmp): + # test doing compare --labels-to and plot --labels-from. + testdata1 = utils.get_test_data("genome-s10.fa.gz.sig") + testdata2 = utils.get_test_data("genome-s11.fa.gz.sig") + testdata3 = utils.get_test_data("genome-s12.fa.gz.sig") + testdata4 = utils.get_test_data("genome-s10+s11.sig") + + labels_csv = runtmp.output("label.csv") + + runtmp.run_sourmash( + "compare", + testdata1, + testdata2, + testdata3, + testdata4, + "-o", + "cmp", + "-k", + "21", + "--dna", + "--labels-to", + labels_csv, + ) + + runtmp.sourmash("plot", "cmp", "--labels-from", labels_csv) + + print(runtmp.last_result.out) + + assert "loading labels from CSV file" in runtmp.last_result.err + + expected = """\ +0\tgenome-s10 +1\tgenome-s11 +2\tgenome-s12 +3\tgenome-s10+s11""" + assert expected in runtmp.last_result.out + + +def test_compare_and_plot_labels_from_changed(runtmp): + # test 'plot --labels-from' with changed labels + testdata1 = utils.get_test_data("genome-s10.fa.gz.sig") + testdata2 = utils.get_test_data("genome-s11.fa.gz.sig") + testdata3 = utils.get_test_data("genome-s12.fa.gz.sig") + testdata4 = utils.get_test_data("genome-s10+s11.sig") + + labels_csv = utils.get_test_data("compare/labels_from-test.csv") + + runtmp.run_sourmash( + "compare", + testdata1, + testdata2, + testdata3, + testdata4, + "-o", + "cmp", + "-k", + "21", + "--dna", + ) + + runtmp.sourmash("plot", "cmp", "--labels-from", labels_csv) + + print(runtmp.last_result.out) + + assert "loading labels from CSV file" in runtmp.last_result.err + + expected = """\ +0\tgenome-s10-CHANGED +1\tgenome-s11-CHANGED +2\tgenome-s12-CHANGED +3\tgenome-s10+s11-CHANGED""" + assert expected in runtmp.last_result.out + + +def test_compare_and_plot_labels_from_error(runtmp): + # 'plot --labels-from ... --labeltext ...' should fail + testdata1 = utils.get_test_data("genome-s10.fa.gz.sig") + testdata2 = utils.get_test_data("genome-s11.fa.gz.sig") + testdata3 = utils.get_test_data("genome-s12.fa.gz.sig") + testdata4 = utils.get_test_data("genome-s10+s11.sig") + + labels_csv = utils.get_test_data("compare/labels_from-test.csv") + + runtmp.run_sourmash( + "compare", + testdata1, + testdata2, + testdata3, + testdata4, + "-o", + "cmp", + "-k", + "21", + "--dna", + ) + + with pytest.raises(SourmashCommandFailed): + runtmp.sourmash( + "plot", + "cmp", + "--labels-from", + labels_csv, + "--labeltext", + labels_csv, + fail_ok=True, + ) + + err = runtmp.last_result.err + assert "ERROR: cannot supply both --labeltext and --labels-from" in err + + @utils.in_tempdir def test_search_query_sig_does_not_exist(c): testdata1 = utils.get_test_data("short.fa")