From 1229dc1d6664004e580e3b80d4c13a8ae12a19be Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Wed, 6 Apr 2022 06:54:11 -0400 Subject: [PATCH] [MRG] better handle some pickfile errors (#1924) * better handle some pickfile errors * Update tests/test_cmd_signature.py Co-authored-by: Tessa Pierce Ward Co-authored-by: Tessa Pierce Ward --- src/sourmash/picklist.py | 8 ++++++- src/sourmash/sourmash_args.py | 8 +++---- tests/test_cmd_signature.py | 39 +++++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/sourmash/picklist.py b/src/sourmash/picklist.py index fc306782b6..f48377e4cb 100644 --- a/src/sourmash/picklist.py +++ b/src/sourmash/picklist.py @@ -1,5 +1,6 @@ "Picklist code for extracting subsets of signatures." import csv +import os from enum import Enum # set up preprocessing functions for column stuff @@ -143,18 +144,23 @@ def load(self, pickfile, column_name): "load pickset, return num empty vals, and set of duplicate vals." pickset = self.init() + if not os.path.exists(pickfile) or not os.path.isfile(pickfile): + raise ValueError(f"pickfile '{pickfile}' must exist and be a regular file") + n_empty_val = 0 dup_vals = set() with open(pickfile, newline='') as csvfile: x = csvfile.readline() # skip leading comment line in case there's a manifest header - if x[0] == '#': + if not x or x[0] == '#': pass else: csvfile.seek(0) r = csv.DictReader(csvfile) + if not r.fieldnames: + raise ValueError(f"empty or improperly formatted pickfile '{pickfile}'") if column_name not in r.fieldnames: raise ValueError(f"column '{column_name}' not in pickfile '{pickfile}'") diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py index dc001013e3..764710a494 100644 --- a/src/sourmash/sourmash_args.py +++ b/src/sourmash/sourmash_args.py @@ -126,15 +126,15 @@ def load_picklist(args): if args.picklist: try: picklist = SignaturePicklist.from_picklist_args(args.picklist) + + notify(f"picking column '{picklist.column_name}' of type '{picklist.coltype}' from '{picklist.pickfile}'") + + n_empty_val, dup_vals = picklist.load(picklist.pickfile, picklist.column_name) except ValueError as exc: error("ERROR: could not load picklist.") error(str(exc)) sys.exit(-1) - notify(f"picking column '{picklist.column_name}' of type '{picklist.coltype}' from '{picklist.pickfile}'") - - n_empty_val, dup_vals = picklist.load(picklist.pickfile, picklist.column_name) - notify(f"loaded {len(picklist.pickset)} distinct values into picklist.") if n_empty_val: notify(f"WARNING: {n_empty_val} empty values in column '{picklist.column_name}' in picklist file") diff --git a/tests/test_cmd_signature.py b/tests/test_cmd_signature.py index b09d05076b..cd93349f7e 100644 --- a/tests/test_cmd_signature.py +++ b/tests/test_cmd_signature.py @@ -1627,6 +1627,45 @@ def test_sig_extract_7_no_ksize(c): assert len(siglist) == 3 +def test_sig_extract_8_empty_picklist_fail(runtmp): + # what happens with an empty picklist? + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + # make empty picklist + picklist_csv = runtmp.output('pick.csv') + with open(picklist_csv, 'w', newline='') as csvfp: + pass + + picklist_arg = f"{picklist_csv}:md5full:md5" + + with pytest.raises(SourmashCommandFailed): + runtmp.sourmash('sig', 'extract', sig47, sig63, '--picklist', picklist_arg) + + err = runtmp.last_result.err + print(err) + + assert "empty or improperly formatted pickfile" in err + + +def test_sig_extract_8_nofile_picklist_fail(runtmp): + # what happens when picklist file does not exist? + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + # picklist file does not exist + picklist_csv = runtmp.output('pick.csv') + picklist_arg = f"{picklist_csv}:md5full:md5" + + with pytest.raises(SourmashCommandFailed): + runtmp.sourmash('sig', 'extract', sig47, sig63, '--picklist', picklist_arg) + + err = runtmp.last_result.err + print(err) + + assert "must exist and be a regular file" in err + + def test_sig_extract_8_picklist_md5(runtmp): # extract 47 from 47, using a picklist w/full md5 sig47 = utils.get_test_data('47.fa.sig')