From 4080b5d566765e90e0088fe851ea3ff2dd0239a8 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Fri, 15 Apr 2022 11:00:17 -0700 Subject: [PATCH] [MRG] provide "protocol" tests for `Index`, `CollectionManifest`, and `LCA_Database` classes (#1936) * add LCA database test for tricky ordering * add test for jaccard ordering to SBTs * add test_index_protocol * add tests of indices after save/load * match Index definition of __len__ in sbt * more index tests * add some generic manifest tests * define abstract base class for CollectionManifest * fix GTDB example, sigh * test hashval_to_idx * add actual test for min num in rankinfo * update 'get_lineage_assignments' in lca_db * update comment * make lid_to_idx and idx_to_ident private * moar comment * add, then hide, RevIndex test * update the LCA_Database protocol * finish testing the rest of the Index classes * cleanup * upd * backport 08ac110dfad4afb76 * remove test for now-implemented func * switch away from a row tuple in CollectionManifest * more clearly separate internals of LCA_Database from public API * add saved/loaded manifest * add test coverage for exceptions in LazyLoadedIndex * add docstrings to manifest code * add docstrings / comments * use names in namedtuple; add containment test * add numerical values to jaccard order tests * add required_keys check * fix diagnostic output during sourmash index #1949 * fix test name --- src/sourmash/index/__init__.py | 9 +- src/sourmash/index/revindex.py | 2 +- src/sourmash/lca/command_index.py | 4 +- src/sourmash/lca/command_rankinfo.py | 13 +- src/sourmash/lca/lca_db.py | 126 +++--- src/sourmash/manifest.py | 116 +++-- src/sourmash/sbt.py | 8 +- tests/test-data/prot/gtdb-subset-lineages.csv | 2 +- tests/test_cmd_signature_fileinfo.py | 4 +- tests/test_index.py | 131 +----- tests/test_index_protocol.py | 416 ++++++++++++++++++ tests/test_lca.py | 146 +++--- tests/test_lca_db_protocol.py | 108 +++++ tests/test_manifest_protocol.py | 166 +++++++ tests/test_sbt.py | 2 
+ 15 files changed, 969 insertions(+), 284 deletions(-) create mode 100644 tests/test_index_protocol.py create mode 100644 tests/test_lca_db_protocol.py create mode 100644 tests/test_manifest_protocol.py diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 5fa521db66..770c677f5c 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -491,7 +491,8 @@ def __bool__(self): return False def __len__(self): - raise NotImplementedError + db = self.db.select(**self.selection_dict) + return len(db) def insert(self, node): raise NotImplementedError @@ -1064,6 +1065,12 @@ class LazyLoadedIndex(Index): """ def __init__(self, filename, manifest): "Create an Index with given filename and manifest." + if not os.path.exists(filename): + raise ValueError(f"'{filename}' must exist when creating LazyLoadedIndex") + + if manifest is None: + raise ValueError("manifest cannot be None") + self.filename = filename self.manifest = manifest diff --git a/src/sourmash/index/revindex.py b/src/sourmash/index/revindex.py index 8951dbc759..2f7074b53f 100644 --- a/src/sourmash/index/revindex.py +++ b/src/sourmash/index/revindex.py @@ -146,7 +146,7 @@ def save(self, path): def load(cls, location): pass - def select(self, ksize=None, moltype=None): + def select(self, ksize=None, moltype=None, **kwargs): if self.template: if ksize: self.template.ksize = ksize diff --git a/src/sourmash/lca/command_index.py b/src/sourmash/lca/command_index.py index 50c70db918..5393bfa316 100644 --- a/src/sourmash/lca/command_index.py +++ b/src/sourmash/lca/command_index.py @@ -277,10 +277,10 @@ def index(args): sys.exit(1) # check -- did the signatures we found have any hashes? 
- if not db.hashval_to_idx: + if not db.hashvals: error('ERROR: no hash values found - are there any signatures?') sys.exit(1) - notify(f'loaded {len(db.hashval_to_idx)} hashes at ksize={args.ksize} scaled={args.scaled}') + notify(f'loaded {len(db.hashvals)} hashes at ksize={args.ksize} scaled={args.scaled}') if picklist: sourmash_args.report_picklist(args, picklist) diff --git a/src/sourmash/lca/command_rankinfo.py b/src/sourmash/lca/command_rankinfo.py index 081f1bf481..ec8aba4a16 100644 --- a/src/sourmash/lca/command_rankinfo.py +++ b/src/sourmash/lca/command_rankinfo.py @@ -19,15 +19,10 @@ def make_lca_counts(dblist, min_num=0): # gather all hashvalue assignments from across all the databases assignments = defaultdict(set) for lca_db in dblist: - for hashval, idx_list in lca_db.hashval_to_idx.items(): - if min_num and len(idx_list) < min_num: - continue - - for idx in idx_list: - lid = lca_db.idx_to_lid.get(idx) - if lid is not None: - lineage = lca_db.lid_to_lineage[lid] - assignments[hashval].add(lineage) + for hashval in lca_db.hashvals: + lineages = lca_db.get_lineage_assignments(hashval, min_num) + if lineages: + assignments[hashval].update(lineages) # now convert to trees -> do LCA & counts counts = defaultdict(int) diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py index 8f88d0c11f..cda1208a60 100644 --- a/src/sourmash/lca/lca_db.py +++ b/src/sourmash/lca/lca_db.py @@ -39,21 +39,21 @@ class LCA_Database(Index): the `ident` keyword argument in `insert`. Integer `idx` indices can be used as keys in dictionary attributes: - * `idx_to_lid`, to get an (optional) lineage index. - * `idx_to_ident`, to retrieve the unique string identifier for that `idx`. + * `_idx_to_lid`, to get an (optional) lineage index. + * `_idx_to_ident`, to retrieve the unique string identifier for that `idx`. Integer `lid` indices can be used as keys in dictionary attributes: - * `lid_to_idx`, to get a set of `idx` with that lineage. 
- * `lid_to_lineage`, to get a lineage for that `lid`. + * `_lid_to_idx`, to get a set of `idx` with that lineage. + * `_lid_to_lineage`, to get a lineage for that `lid`. - `lineage_to_lid` is a dictionary with tuples of LineagePair as keys, + `_lineage_to_lid` is a dictionary with tuples of LineagePair as keys, `lid` as values. - `ident_to_name` is a dictionary from unique str identifer to a name. + `_ident_to_name` is a dictionary from unique str identifer to a name. - `ident_to_idx` is a dictionary from unique str identifer to integer `idx`. + `_ident_to_idx` is a dictionary from unique str identifer to integer `idx`. - `hashval_to_idx` is a dictionary from individual hash values to sets of + `_hashval_to_idx` is a dictionary from individual hash values to sets of `idx`. """ is_database = True @@ -70,12 +70,12 @@ def __init__(self, ksize, scaled, moltype='DNA'): self._next_index = 0 self._next_lid = 0 - self.ident_to_name = {} - self.ident_to_idx = {} - self.idx_to_lid = {} - self.lineage_to_lid = {} - self.lid_to_lineage = {} - self.hashval_to_idx = defaultdict(set) + self._ident_to_name = {} + self._ident_to_idx = {} + self._idx_to_lid = {} + self._lineage_to_lid = {} + self._lid_to_lineage = {} + self._hashval_to_idx = defaultdict(set) self.picklists = [] @property @@ -91,7 +91,7 @@ def _invalidate_cache(self): def _get_ident_index(self, ident, fail_on_duplicate=False): "Get (create if nec) a unique int id, idx, for each identifier." - idx = self.ident_to_idx.get(ident) + idx = self._ident_to_idx.get(ident) if fail_on_duplicate: assert idx is None # should be no duplicate identities @@ -99,14 +99,14 @@ def _get_ident_index(self, ident, fail_on_duplicate=False): idx = self._next_index self._next_index += 1 - self.ident_to_idx[ident] = idx + self._ident_to_idx[ident] = idx return idx def _get_lineage_id(self, lineage): "Get (create if nec) a unique lineage ID for each LineagePair tuples." # does one exist already? 
- lid = self.lineage_to_lid.get(lineage) + lid = self._lineage_to_lid.get(lineage) # nope - create one. Increment next_lid. if lid is None: @@ -114,8 +114,8 @@ def _get_lineage_id(self, lineage): self._next_lid += 1 # build mappings - self.lineage_to_lid[lineage] = lid - self.lid_to_lineage[lid] = lineage + self._lineage_to_lid[lineage] = lid + self._lid_to_lineage[lid] = lineage return lid @@ -147,14 +147,14 @@ def insert(self, sig, ident=None, lineage=None): if not ident: ident = str(sig) - if ident in self.ident_to_name: + if ident in self._ident_to_name: raise ValueError("signature '{}' is already in this LCA db.".format(ident)) # before adding, invalide any caching from @cached_property self._invalidate_cache() # store full name - self.ident_to_name[ident] = sig.name + self._ident_to_name[ident] = sig.name # identifier -> integer index (idx) idx = self._get_ident_index(ident, fail_on_duplicate=True) @@ -166,12 +166,12 @@ def insert(self, sig, ident=None, lineage=None): lid = self._get_lineage_id(lineage) # map idx to lid as well. 
- self.idx_to_lid[idx] = lid + self._idx_to_lid[idx] = lid except TypeError: raise ValueError('lineage cannot be used as a key?!') for hashval in minhash.hashes: - self.hashval_to_idx[hashval].add(idx) + self._hashval_to_idx[hashval].add(idx) return len(minhash) @@ -290,8 +290,8 @@ def load(cls, db_name): vv = tuple(vv) lid_to_lineage[int(k)] = vv lineage_to_lid[vv] = int(k) - db.lid_to_lineage = lid_to_lineage - db.lineage_to_lid = lineage_to_lid + db._lid_to_lineage = lid_to_lineage + db._lineage_to_lid = lineage_to_lid # convert hashval -> lineage index keys to integers (looks like # JSON doesn't have a 64 bit type so stores them as strings) @@ -300,21 +300,21 @@ def load(cls, db_name): for k, v in hashval_to_idx_2.items(): hashval_to_idx[int(k)] = v - db.hashval_to_idx = hashval_to_idx + db._hashval_to_idx = hashval_to_idx - db.ident_to_name = load_d['ident_to_name'] - db.ident_to_idx = load_d['ident_to_idx'] + db._ident_to_name = load_d['ident_to_name'] + db._ident_to_idx = load_d['ident_to_idx'] - db.idx_to_lid = {} + db._idx_to_lid = {} for k, v in load_d['idx_to_lid'].items(): - db.idx_to_lid[int(k)] = v + db._idx_to_lid[int(k)] = v - if db.ident_to_idx: - db._next_index = max(db.ident_to_idx.values()) + 1 + if db._ident_to_idx: + db._next_index = max(db._ident_to_idx.values()) + 1 else: db._next_index = 0 - if db.idx_to_lid: - db._next_lid = max(db.idx_to_lid.values()) + 1 + if db._idx_to_lid: + db._next_lid = max(db._idx_to_lid.values()) + 1 else: db._next_lid = 0 @@ -345,18 +345,18 @@ def save(self, db_name): # convert lineage internals from tuples to dictionaries d = OrderedDict() - for k, v in self.lid_to_lineage.items(): + for k, v in self._lid_to_lineage.items(): d[k] = dict([ (vv.rank, vv.name) for vv in v ]) save_d['lid_to_lineage'] = d # convert values from sets to lists, so that JSON knows how to save save_d['hashval_to_idx'] = \ - dict((k, list(v)) for (k, v) in self.hashval_to_idx.items()) + dict((k, list(v)) for (k, v) in 
self._hashval_to_idx.items()) - save_d['ident_to_name'] = self.ident_to_name - save_d['ident_to_idx'] = self.ident_to_idx - save_d['idx_to_lid'] = self.idx_to_lid - save_d['lid_to_lineage'] = self.lid_to_lineage + save_d['ident_to_name'] = self._ident_to_name + save_d['ident_to_idx'] = self._ident_to_idx + save_d['idx_to_lid'] = self._idx_to_lid + save_d['lid_to_lineage'] = self._lid_to_lineage json.dump(save_d, fp) @@ -378,27 +378,45 @@ def downsample_scaled(self, scaled): # filter out all hashes over max_hash in value. new_hashvals = {} - for k, v in self.hashval_to_idx.items(): + for k, v in self._hashval_to_idx.items(): if k < max_hash: new_hashvals[k] = v - self.hashval_to_idx = new_hashvals + self._hashval_to_idx = new_hashvals self.scaled = scaled - def get_lineage_assignments(self, hashval): + @property + def hashvals(self): + "Return all hashvals stored in this database." + return self._hashval_to_idx.keys() + + def get_lineage_assignments(self, hashval, min_num=None): """ Get a list of lineages for this hashval. """ x = [] - idx_list = self.hashval_to_idx.get(hashval, []) + idx_list = self._hashval_to_idx.get(hashval, []) + + if min_num and len(idx_list) < min_num: + return [] + for idx in idx_list: - lid = self.idx_to_lid.get(idx, None) + lid = self._idx_to_lid.get(idx, None) if lid is not None: - lineage = self.lid_to_lineage[lid] + lineage = self._lid_to_lineage[lid] x.append(lineage) return x + def get_identifiers_for_hashval(self, hashval): + """ + Get a list of identifiers for signatures containing this hashval + """ + idx_list = self._hashval_to_idx.get(hashval, []) + + for idx in idx_list: + yield self._idx_to_ident[idx] + @cached_property def _signatures(self): "Create a _signatures member dictionary that contains {idx: sigobj}." 
@@ -422,7 +440,7 @@ def _signatures(self): temp_vals = defaultdict(list) # invert the hashval_to_idx dictionary - for (hashval, idlist) in self.hashval_to_idx.items(): + for (hashval, idlist) in self._hashval_to_idx.items(): for idx in idlist: temp_hashes = temp_vals[idx] temp_hashes.append(hashval) @@ -445,8 +463,8 @@ def _signatures(self): sigd = {} for idx, mh in mhd.items(): - ident = self.idx_to_ident[idx] - name = self.ident_to_name[ident] + ident = self._idx_to_ident[idx] + name = self._ident_to_name[ident] ss = SourmashSignature(mh, name=name) if passes_all_picklists(ss, self.picklists): @@ -481,7 +499,7 @@ def find(self, search_fn, query, **kwargs): c = Counter() query_hashes = set(query_mh.hashes) for hashval in query_hashes: - idx_list = self.hashval_to_idx.get(hashval, []) + idx_list = self._hashval_to_idx.get(hashval, []) for idx in idx_list: c[idx] += 1 @@ -523,16 +541,16 @@ def find(self, search_fn, query, **kwargs): yield IndexSearchResult(score, subj, self.location) @cached_property - def lid_to_idx(self): + def _lid_to_idx(self): d = defaultdict(set) - for idx, lid in self.idx_to_lid.items(): + for idx, lid in self._idx_to_lid.items(): d[lid].add(idx) return d @cached_property - def idx_to_ident(self): + def _idx_to_ident(self): d = defaultdict(set) - for ident, idx in self.ident_to_idx.items(): + for ident, idx in self._ident_to_idx.items(): assert idx not in d d[idx] = ident return d diff --git a/src/sourmash/manifest.py b/src/sourmash/manifest.py index 78c1a139ff..2447d94067 100644 --- a/src/sourmash/manifest.py +++ b/src/sourmash/manifest.py @@ -3,11 +3,12 @@ """ import csv import ast +from abc import abstractmethod from sourmash.picklist import SignaturePicklist -class CollectionManifest: +class BaseCollectionManifest: """ Signature metadata for a collection of signatures. 
@@ -25,37 +26,6 @@ class CollectionManifest: 'scaled', 'n_hashes', 'with_abundance', 'name', 'filename') - def __init__(self, rows): - "Initialize from an iterable of metadata dictionaries." - self.rows = () - self._md5_set = set() - - self._add_rows(rows) - - def _add_rows(self, rows): - self.rows += tuple(rows) - - # maintain a fast lookup table for md5sums - md5set = self._md5_set - for row in self.rows: - md5set.add(row['md5']) - - def __iadd__(self, other): - self._add_rows(other.rows) - return self - - def __add__(self, other): - return CollectionManifest(self.rows + other.rows) - - def __bool__(self): - return bool(self.rows) - - def __len__(self): - return len(self.rows) - - def __eq__(self, other): - return self.rows == other.rows - @classmethod def load_from_filename(cls, filename): with open(filename, newline="") as fp: @@ -154,11 +124,91 @@ def create_manifest(cls, locations_iter, *, include_signature=True): """ manifest_list = [] for ss, location in locations_iter: - row = cls.make_manifest_row(ss, location, include_signature=True) + row = cls.make_manifest_row(ss, location, + include_signature=include_signature) manifest_list.append(row) return cls(manifest_list) + ## implement me + @abstractmethod + def __add__(self, other): + "Add two manifests" + + @abstractmethod + def __bool__(self): + "Test if manifest is empty" + + @abstractmethod + def __len__(self): + "Get number of entries in manifest" + + @abstractmethod + def __eq__(self, other): + "Check for equality of manifest based on rows" + + @abstractmethod + def select_to_manifest(self, **kwargs): + "Select compatible signatures" + + @abstractmethod + def filter_rows(self, row_filter_fn): + "Filter rows based on a pattern matching function." + + @abstractmethod + def filter_on_columns(self, col_filter_fn, col_names): + "Filter on column values." 
+ + @abstractmethod + def locations(self): + "Return a list of distinct locations" + + @abstractmethod + def __contains__(self, ss): + "Determine if a particular SourmashSignature is in this manifest." + + @abstractmethod + def to_picklist(self): + "Convert manifest to a picklist." + + +class CollectionManifest(BaseCollectionManifest): + """ + An in-memory manifest that simply stores the rows in a list. + """ + def __init__(self, rows): + "Initialize from an iterable of metadata dictionaries." + self.rows = [] + self._md5_set = set() + + self._add_rows(rows) + + def _add_rows(self, rows): + self.rows.extend(rows) + + # maintain a fast check for md5sums for __contains__ check. + md5set = self._md5_set + for row in self.rows: + md5set.add(row['md5']) + + def __iadd__(self, other): + self._add_rows(other.rows) + return self + + def __add__(self, other): + mf = CollectionManifest(self.rows) + mf._add_rows(other.rows) + return mf + + def __bool__(self): + return bool(self.rows) + + def __len__(self): + return len(self.rows) + + def __eq__(self, other): + return self.rows == other.rows + def _select(self, *, ksize=None, moltype=None, scaled=0, num=0, containment=False, abund=None, picklist=None): """Yield manifest rows for sigs that match the specified requirements. 
diff --git a/src/sourmash/sbt.py b/src/sourmash/sbt.py index bb001fd940..72cad16c30 100644 --- a/src/sourmash/sbt.py +++ b/src/sourmash/sbt.py @@ -670,7 +670,10 @@ def save(self, path, storage=None, sparseness=0.0, structure_only=False): nodes = {} leaves = {} - total_nodes = len(self) + + internal_nodes = set(self._nodes).union(self._missing_nodes) + total_nodes = len(self) + len(internal_nodes) + manifest_rows = [] for n, (i, node) in enumerate(self): if node is None: @@ -1191,8 +1194,7 @@ def _fill_up(self, search_fn, *args, **kwargs): debug("processed {}, in queue {}", processed, len(queue), sep='\r') def __len__(self): - internal_nodes = set(self._nodes).union(self._missing_nodes) - return len(internal_nodes) + len(self._leaves) + return len(self._leaves) def print_dot(self): print(""" diff --git a/tests/test-data/prot/gtdb-subset-lineages.csv b/tests/test-data/prot/gtdb-subset-lineages.csv index bb401a3453..1dd35fcb47 100644 --- a/tests/test-data/prot/gtdb-subset-lineages.csv +++ b/tests/test-data/prot/gtdb-subset-lineages.csv @@ -1,3 +1,3 @@ -accession,gtdb_id,superkingdom,phylum,class,order,family,genus,species +accession,superkingdom,phylum,class,order,family,genus,species GCA_001593935,d__Archaea,p__Crenarchaeota,c__Bathyarchaeia,o__B26-1,f__B26-1,g__B26-1,s__B26-1 sp001593935 GCA_001593925,d__Archaea,p__Crenarchaeota,c__Bathyarchaeia,o__B26-1,f__B26-1,g__B26-1,s__B26-1 sp001593925 diff --git a/tests/test_cmd_signature_fileinfo.py b/tests/test_cmd_signature_fileinfo.py index ee90fc7ba4..6df3aed33f 100644 --- a/tests/test_cmd_signature_fileinfo.py +++ b/tests/test_cmd_signature_fileinfo.py @@ -124,7 +124,7 @@ def test_fileinfo_3_sbt_zip(runtmp): location: protein.sbt.zip is database? yes has manifest? yes -num signatures: 3 +num signatures: 2 total hashes: 8214 summary of sketches: 2 sketches with protein, k=19, scaled=100 8214 total hashes @@ -290,7 +290,7 @@ def test_fileinfo_7_sbt_json(runtmp, db): location: {dbfile} is database? yes has manifest? 
no -num signatures: 13 +num signatures: 7 total hashes: 3500 summary of sketches: 7 sketches with DNA, k=31, num=500 3500 total hashes diff --git a/tests/test_index.py b/tests/test_index.py index c0ee430ea2..24468db9bf 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -89,84 +89,6 @@ def test_simple_index(n_children): assert tree_found == set(linear_found) -def test_linear_index_search(): - # test LinearIndex searching - all in memory - sig2 = utils.get_test_data('2.fa.sig') - sig47 = utils.get_test_data('47.fa.sig') - sig63 = utils.get_test_data('63.fa.sig') - - ss2 = sourmash.load_one_signature(sig2, ksize=31) - ss47 = sourmash.load_one_signature(sig47) - ss63 = sourmash.load_one_signature(sig63) - - lidx = LinearIndex() - lidx.insert(ss2) - lidx.insert(ss47) - lidx.insert(ss63) - - # now, search for sig2 - sr = lidx.search(ss2, threshold=1.0) - print([s[1].name for s in sr]) - assert len(sr) == 1 - assert sr[0][1] == ss2 - - # search for sig47 with lower threshold; search order not guaranteed. - sr = lidx.search(ss47, threshold=0.1) - print([s[1].name for s in sr]) - assert len(sr) == 2 - sr.sort(key=lambda x: -x[0]) - assert sr[0][1] == ss47 - assert sr[1][1] == ss63 - - # search for sig63 with lower threshold; search order not guaranteed. 
- sr = lidx.search(ss63, threshold=0.1) - print([s[1].name for s in sr]) - assert len(sr) == 2 - sr.sort(key=lambda x: -x[0]) - assert sr[0][1] == ss63 - assert sr[1][1] == ss47 - - # search for sig63 with high threshold => 1 match - sr = lidx.search(ss63, threshold=0.8) - print([s[1].name for s in sr]) - assert len(sr) == 1 - sr.sort(key=lambda x: -x[0]) - assert sr[0][1] == ss63 - - -def test_linear_index_prefetch(): - # check that prefetch does basic things right: - sig2 = utils.get_test_data('2.fa.sig') - sig47 = utils.get_test_data('47.fa.sig') - sig63 = utils.get_test_data('63.fa.sig') - - ss2 = sourmash.load_one_signature(sig2, ksize=31) - ss47 = sourmash.load_one_signature(sig47) - ss63 = sourmash.load_one_signature(sig63) - - lidx = LinearIndex() - lidx.insert(ss2) - lidx.insert(ss47) - lidx.insert(ss63) - - # search for ss2 - results = [] - for result in lidx.prefetch(ss2, threshold_bp=0): - results.append(result) - - assert len(results) == 1 - assert results[0].signature == ss2 - - # search for ss47 - expect two results - results = [] - for result in lidx.prefetch(ss47, threshold_bp=0): - results.append(result) - - assert len(results) == 2 - assert results[0].signature == ss47 - assert results[1].signature == ss63 - - def test_linear_index_prefetch_empty(): # check that an exception is raised upon for an empty LinearIndex sig2 = utils.get_test_data('2.fa.sig') @@ -219,32 +141,6 @@ def minhash(self): assert "don't touch me!" 
in str(e.value) -def test_linear_index_gather(): - # test LinearIndex gather - sig2 = utils.get_test_data('2.fa.sig') - sig47 = utils.get_test_data('47.fa.sig') - sig63 = utils.get_test_data('63.fa.sig') - - ss2 = sourmash.load_one_signature(sig2, ksize=31) - ss47 = sourmash.load_one_signature(sig47) - ss63 = sourmash.load_one_signature(sig63) - - lidx = LinearIndex() - lidx.insert(ss2) - lidx.insert(ss47) - lidx.insert(ss63) - - matches = lidx.gather(ss2) - assert len(matches) == 1 - assert matches[0][0] == 1.0 - assert matches[0][1] == ss2 - - matches = lidx.gather(ss47) - assert len(matches) == 1 - assert matches[0][0] == 1.0 - assert matches[0][1] == ss47 - - def test_linear_index_search_subj_has_abundance(): # check that search signatures in the index are flattened appropriately. queryfile = utils.get_test_data('47.fa.sig') @@ -2347,15 +2243,6 @@ def test_lazy_index_4_bool(): assert lazy -def test_lazy_index_5_len(): - # test some basic features of LazyLinearIndex - lidx = LinearIndex() - lazy = LazyLinearIndex(lidx) - - with pytest.raises(NotImplementedError): - len(lazy) - - def test_lazy_index_wraps_multi_index_location(): # check that 'location' works fine when MultiIndex is wrapped by # LazyLinearIndex. 
@@ -2456,6 +2343,24 @@ def test_lazy_loaded_index_3_find(runtmp): assert len(x) == 0 +def test_lazy_loaded_index_4_nofile(runtmp): + # test check for filename must exist + with pytest.raises(ValueError) as exc: + index.LazyLoadedIndex(runtmp.output('xyz'), True) + + assert "must exist when creating" in str(exc) + + +def test_lazy_loaded_index_4_nomanifest(runtmp): + # test check for empty manifest + sig2 = utils.get_test_data("2.fa.sig") + + with pytest.raises(ValueError) as exc: + index.LazyLoadedIndex(sig2, None) + + assert "manifest cannot be None" in str(exc) + + def test_revindex_index_search(): # confirm that RevIndex works sig2 = utils.get_test_data("2.fa.sig") diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py new file mode 100644 index 0000000000..ff08d5dc46 --- /dev/null +++ b/tests/test_index_protocol.py @@ -0,0 +1,416 @@ +""" +Tests for the 'Index' class and protocol. All Index classes should support +this functionality. +""" + +import pytest + +import sourmash +from sourmash import SourmashSignature +from sourmash.index import (LinearIndex, ZipFileLinearIndex, + LazyLinearIndex, MultiIndex, + StandaloneManifestIndex, LazyLoadedIndex) +from sourmash.index.revindex import RevIndex +from sourmash.sbt import SBT, GraphFactory +from sourmash.manifest import CollectionManifest +from sourmash.lca.lca_db import LCA_Database + +import sourmash_tst_utils as utils + + +def _load_three_sigs(): + # utility function - load & return these three sigs. 
+ sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47) + ss63 = sourmash.load_one_signature(sig63) + + return [ss2, ss47, ss63] + + +def build_linear_index(runtmp): + ss2, ss47, ss63 = _load_three_sigs() + + lidx = LinearIndex() + lidx.insert(ss2) + lidx.insert(ss47) + lidx.insert(ss63) + + return lidx + + +def build_lazy_linear_index(runtmp): + lidx = build_linear_index(runtmp) + return LazyLinearIndex(lidx) + + +def build_sbt_index(runtmp): + ss2, ss47, ss63 = _load_three_sigs() + + factory = GraphFactory(5, 100, 3) + root = SBT(factory, d=2) + + root.insert(ss2) + root.insert(ss47) + root.insert(ss63) + + return root + + +def build_sbt_index_save_load(runtmp): + root = build_sbt_index(runtmp) + out = runtmp.output('xyz.sbt.zip') + root.save(out) + + return sourmash.load_file_as_index(out) + + +def build_zipfile_index(runtmp): + from sourmash.sourmash_args import SaveSignatures_ZipFile + + location = runtmp.output('index.zip') + with SaveSignatures_ZipFile(location) as save_sigs: + for ss in _load_three_sigs(): + save_sigs.add(ss) + + idx = ZipFileLinearIndex.load(location) + return idx + + +def build_multi_index(runtmp): + siglist = _load_three_sigs() + lidx = LinearIndex(siglist) + + mi = MultiIndex.load([lidx], [None], None) + return mi + + +def build_standalone_manifest_index(runtmp): + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47) + ss63 = sourmash.load_one_signature(sig63) + + siglist = [(ss2, sig2), (ss47, sig47), (ss63, sig63)] + + rows = [] + rows.extend((CollectionManifest.make_manifest_row(ss, loc) for ss, loc in siglist )) + mf = CollectionManifest(rows) + mf_filename = runtmp.output("mf.csv") + + 
mf.write_to_filename(mf_filename) + + idx = StandaloneManifestIndex.load(mf_filename) + return idx + + +def build_lca_index(runtmp): + siglist = _load_three_sigs() + db = LCA_Database(31, 1000, 'DNA') + for ss in siglist: + db.insert(ss) + + return db + + +def build_lca_index_save_load(runtmp): + db = build_lca_index(runtmp) + outfile = runtmp.output('db.lca.json') + db.save(outfile) + + return sourmash.load_file_as_index(outfile) + + +def build_lazy_loaded_index(runtmp): + db = build_lca_index(runtmp) + outfile = runtmp.output('db.lca.json') + db.save(outfile) + + mf = CollectionManifest.create_manifest(db._signatures_with_internal()) + return LazyLoadedIndex(outfile, mf) + + +def build_revindex(runtmp): + ss2, ss47, ss63 = _load_three_sigs() + + lidx = RevIndex(template=ss2.minhash) + lidx.insert(ss2) + lidx.insert(ss47) + lidx.insert(ss63) + + return lidx + + +# +# create a fixture 'index_obj' that is parameterized by all of these +# building functions. +# + +@pytest.fixture(params=[build_linear_index, + build_lazy_linear_index, + build_sbt_index, + build_zipfile_index, + build_multi_index, + build_standalone_manifest_index, + build_lca_index, + build_sbt_index_save_load, + build_lca_index_save_load, + build_lazy_loaded_index, +# build_revindex, + ] +) +def index_obj(request, runtmp): + build_fn = request.param + + # build on demand + return build_fn(runtmp) + + +### +### generic Index tests go here +### + + +def test_index_search_exact_match(index_obj): + # search for an exact match + ss2, ss47, ss63 = _load_three_sigs() + + sr = index_obj.search(ss2, threshold=1.0) + print([s[1].name for s in sr]) + assert len(sr) == 1 + assert sr[0].signature.minhash == ss2.minhash + assert sr[0].score == 1.0 + + +def test_index_search_lower_threshold(index_obj): + # search at a lower threshold/multiple results with ss47 + ss2, ss47, ss63 = _load_three_sigs() + + sr = index_obj.search(ss47, threshold=0.1) + print([s[1].name for s in sr]) + assert len(sr) == 2 + 
sr.sort(key=lambda x: -x[0]) + assert sr[0].signature.minhash == ss47.minhash + assert sr[0].score == 1.0 + assert sr[1].signature.minhash == ss63.minhash + assert round(sr[1].score, 2) == 0.32 + + +def test_index_search_lower_threshold_2(index_obj): + # search at a lower threshold/multiple results with ss63 + ss2, ss47, ss63 = _load_three_sigs() + + sr = index_obj.search(ss63, threshold=0.1) + print([s[1].name for s in sr]) + assert len(sr) == 2 + sr.sort(key=lambda x: -x[0]) + assert sr[0].signature.minhash == ss63.minhash + assert sr[0].score == 1.0 + assert sr[1].signature.minhash == ss47.minhash + assert round(sr[1].score, 2) == 0.32 + + +def test_index_search_higher_threshold_2(index_obj): + # search at a higher threshold/one match + ss2, ss47, ss63 = _load_three_sigs() + + # search for sig63 with high threshold => 1 match + sr = index_obj.search(ss63, threshold=0.8) + print([s[1].name for s in sr]) + assert len(sr) == 1 + sr.sort(key=lambda x: -x[0]) + assert sr[0].signature.minhash == ss63.minhash + assert sr[0].score == 1.0 + + +def test_index_search_containment(index_obj): + # search for containment at a low threshold/multiple results with ss63 + ss2, ss47, ss63 = _load_three_sigs() + + sr = index_obj.search(ss63, do_containment=True, threshold=0.1) + print([s[1].name for s in sr]) + assert len(sr) == 2 + sr.sort(key=lambda x: -x[0]) + assert sr[0].signature.minhash == ss63.minhash + assert sr[0].score == 1.0 + assert sr[1].signature.minhash == ss47.minhash + assert round(sr[1].score, 2) == 0.48 + + +def test_index_signatures(index_obj): + # signatures works? + siglist = list(index_obj.signatures()) + + ss2, ss47, ss63 = _load_three_sigs() + assert len(siglist) == 3 + + # check md5sums, since 'in' doesn't always work + md5s = set(( ss.md5sum() for ss in siglist )) + assert ss2.md5sum() in md5s + assert ss47.md5sum() in md5s + assert ss63.md5sum() in md5s + + +def test_index_len(index_obj): + # len works? 
+ assert len(index_obj) == 3 + + +def test_index_bool(index_obj): + # bool works? + assert bool(index_obj) + + +def test_index_select_basic(index_obj): + # select does the basic thing ok + idx = index_obj.select(ksize=31, moltype='DNA', abund=False, + containment=True, scaled=1000, num=0, picklist=None) + + assert len(idx) == 3 + siglist = list(idx.signatures()) + assert len(siglist) == 3 + + # check md5sums, since 'in' doesn't always work + md5s = set(( ss.md5sum() for ss in siglist )) + ss2, ss47, ss63 = _load_three_sigs() + assert ss2.md5sum() in md5s + assert ss47.md5sum() in md5s + assert ss63.md5sum() in md5s + + +def test_index_select_nada(index_obj): + # select works ok when nothing matches! + + # CTB: currently this EITHER raises a ValueError OR returns an empty + # Index object, depending on implementation. :think: + # See: https://github.com/sourmash-bio/sourmash/issues/1940 + try: + idx = index_obj.select(ksize=21) + except ValueError: + idx = LinearIndex([]) + + assert len(idx) == 0 + siglist = list(idx.signatures()) + assert len(siglist) == 0 + + +def test_index_prefetch(index_obj): + # test basic prefetch + ss2, ss47, ss63 = _load_three_sigs() + + # search for ss2 + results = [] + for result in index_obj.prefetch(ss2, threshold_bp=0): + results.append(result) + + assert len(results) == 1 + assert results[0].signature.minhash == ss2.minhash + + # search for ss47 - expect two results + results = [] + for result in index_obj.prefetch(ss47, threshold_bp=0): + results.append(result) + + assert len(results) == 2 + assert results[0].signature.minhash == ss47.minhash + assert results[1].signature.minhash == ss63.minhash + + +def test_index_gather(index_obj): + # test basic gather + ss2, ss47, ss63 = _load_three_sigs() + + matches = index_obj.gather(ss2) + assert len(matches) == 1 + assert matches[0].score == 1.0 + assert matches[0].signature.minhash == ss2.minhash + + matches = index_obj.gather(ss47) + assert len(matches) == 1 + assert matches[0].score == 
1.0 + assert matches[0].signature.minhash == ss47.minhash + + +def test_linear_gather_threshold_1(index_obj): + # test gather() method, in some detail + ss2, ss47, ss63 = _load_three_sigs() + + # now construct query signatures with specific numbers of hashes -- + # note, these signatures all have scaled=1000. + + mins = list(sorted(ss2.minhash.hashes)) + new_mh = ss2.minhash.copy_and_clear() + + # query with empty hashes + assert not new_mh + with pytest.raises(ValueError): + index_obj.gather(SourmashSignature(new_mh)) + + # add one hash + new_mh.add_hash(mins.pop()) + assert len(new_mh) == 1 + + results = index_obj.gather(SourmashSignature(new_mh)) + assert len(results) == 1 + containment, match_sig, name = results[0] + assert containment == 1.0 + assert match_sig.minhash == ss2.minhash + + # check with a threshold -> should be no results. + with pytest.raises(ValueError): + index_obj.gather(SourmashSignature(new_mh), threshold_bp=5000) + + # add three more hashes => length of 4 + new_mh.add_hash(mins.pop()) + new_mh.add_hash(mins.pop()) + new_mh.add_hash(mins.pop()) + assert len(new_mh) == 4 + + results = index_obj.gather(SourmashSignature(new_mh)) + assert len(results) == 1 + containment, match_sig, name = results[0] + assert containment == 1.0 + assert match_sig.minhash == ss2.minhash + + # check with a too-high threshold -> should be no results. + with pytest.raises(ValueError): + index_obj.gather(SourmashSignature(new_mh), threshold_bp=5000) + + +def test_gather_threshold_5(index_obj): + # test gather() method, in some detail + ss2, ss47, ss63 = _load_three_sigs() + + # now construct query signatures with specific numbers of hashes -- + # note, these signatures all have scaled=1000. 
+ + mins = list(sorted(ss2.minhash.hashes.keys())) + new_mh = ss2.minhash.copy_and_clear() + + # add five hashes + for i in range(5): + new_mh.add_hash(mins.pop()) + new_mh.add_hash(mins.pop()) + new_mh.add_hash(mins.pop()) + new_mh.add_hash(mins.pop()) + new_mh.add_hash(mins.pop()) + + # should get a result with no threshold (any match at all is returned) + results = index_obj.gather(SourmashSignature(new_mh)) + assert len(results) == 1 + containment, match_sig, name = results[0] + assert containment == 1.0 + assert match_sig.minhash == ss2.minhash + + # now, check with a threshold_bp that should be meet-able. + results = index_obj.gather(SourmashSignature(new_mh), threshold_bp=5000) + assert len(results) == 1 + containment, match_sig, name = results[0] + assert containment == 1.0 + assert match_sig.minhash == ss2.minhash diff --git a/tests/test_lca.py b/tests/test_lca.py index 990de48864..84e6843fa7 100644 --- a/tests/test_lca.py +++ b/tests/test_lca.py @@ -117,23 +117,23 @@ def test_api_create_insert(): lca_db.insert(ss) ident = ss.name - assert len(lca_db.ident_to_name) == 1 - assert ident in lca_db.ident_to_name - assert lca_db.ident_to_name[ident] == ident - assert len(lca_db.ident_to_idx) == 1 - assert lca_db.ident_to_idx[ident] == 0 - assert len(lca_db.hashval_to_idx) == len(ss.minhash) - assert len(lca_db.idx_to_ident) == 1 - assert lca_db.idx_to_ident[0] == ident + assert len(lca_db._ident_to_name) == 1 + assert ident in lca_db._ident_to_name + assert lca_db._ident_to_name[ident] == ident + assert len(lca_db._ident_to_idx) == 1 + assert lca_db._ident_to_idx[ident] == 0 + assert len(lca_db._hashval_to_idx) == len(ss.minhash) + assert len(lca_db._idx_to_ident) == 1 + assert lca_db._idx_to_ident[0] == ident set_of_values = set() - for vv in lca_db.hashval_to_idx.values(): + for vv in lca_db._hashval_to_idx.values(): set_of_values.update(vv) assert len(set_of_values) == 1 assert set_of_values == { 0 } - assert not lca_db.idx_to_lid # no lineage added - assert 
not lca_db.lid_to_lineage # no lineage added + assert not lca_db._idx_to_lid # no lineage added + assert not lca_db._lid_to_lineage # no lineage added def test_api_create_insert_bad_ksize(): @@ -198,25 +198,25 @@ def test_api_create_insert_ident(): lca_db.insert(ss, ident='foo') ident = 'foo' - assert len(lca_db.ident_to_name) == 1 - assert ident in lca_db.ident_to_name - assert lca_db.ident_to_name[ident] == ss.name - assert len(lca_db.ident_to_idx) == 1 - assert lca_db.ident_to_idx[ident] == 0 - assert len(lca_db.hashval_to_idx) == len(ss.minhash) - assert len(lca_db.idx_to_ident) == 1 - assert lca_db.idx_to_ident[0] == ident + assert len(lca_db._ident_to_name) == 1 + assert ident in lca_db._ident_to_name + assert lca_db._ident_to_name[ident] == ss.name + assert len(lca_db._ident_to_idx) == 1 + assert lca_db._ident_to_idx[ident] == 0 + assert len(lca_db._hashval_to_idx) == len(ss.minhash) + assert len(lca_db._idx_to_ident) == 1 + assert lca_db._idx_to_ident[0] == ident set_of_values = set() - for vv in lca_db.hashval_to_idx.values(): + for vv in lca_db._hashval_to_idx.values(): set_of_values.update(vv) assert len(set_of_values) == 1 assert set_of_values == { 0 } - assert not lca_db.idx_to_lid # no lineage added - assert not lca_db.lid_to_lineage # no lineage added - assert not lca_db.lineage_to_lid - assert not lca_db.lid_to_idx + assert not lca_db._idx_to_lid # no lineage added + assert not lca_db._lid_to_lineage # no lineage added + assert not lca_db._lineage_to_lid + assert not lca_db._lid_to_idx def test_api_create_insert_two(): @@ -232,34 +232,34 @@ def test_api_create_insert_two(): ident = 'foo' ident2 = 'bar' - assert len(lca_db.ident_to_name) == 2 - assert ident in lca_db.ident_to_name - assert ident2 in lca_db.ident_to_name - assert lca_db.ident_to_name[ident] == ss.name - assert lca_db.ident_to_name[ident2] == ss2.name + assert len(lca_db._ident_to_name) == 2 + assert ident in lca_db._ident_to_name + assert ident2 in lca_db._ident_to_name + assert 
lca_db._ident_to_name[ident] == ss.name + assert lca_db._ident_to_name[ident2] == ss2.name - assert len(lca_db.ident_to_idx) == 2 - assert lca_db.ident_to_idx[ident] == 0 - assert lca_db.ident_to_idx[ident2] == 1 + assert len(lca_db._ident_to_idx) == 2 + assert lca_db._ident_to_idx[ident] == 0 + assert lca_db._ident_to_idx[ident2] == 1 combined_mins = set(ss.minhash.hashes.keys()) combined_mins.update(set(ss2.minhash.hashes.keys())) - assert len(lca_db.hashval_to_idx) == len(combined_mins) + assert len(lca_db._hashval_to_idx) == len(combined_mins) - assert len(lca_db.idx_to_ident) == 2 - assert lca_db.idx_to_ident[0] == ident - assert lca_db.idx_to_ident[1] == ident2 + assert len(lca_db._idx_to_ident) == 2 + assert lca_db._idx_to_ident[0] == ident + assert lca_db._idx_to_ident[1] == ident2 set_of_values = set() - for vv in lca_db.hashval_to_idx.values(): + for vv in lca_db._hashval_to_idx.values(): set_of_values.update(vv) assert len(set_of_values) == 2 assert set_of_values == { 0, 1 } - assert not lca_db.idx_to_lid # no lineage added - assert not lca_db.lid_to_lineage # no lineage added - assert not lca_db.lineage_to_lid - assert not lca_db.lid_to_idx + assert not lca_db._idx_to_lid # no lineage added + assert not lca_db._lid_to_lineage # no lineage added + assert not lca_db._lineage_to_lid + assert not lca_db._lid_to_idx def test_api_create_insert_w_lineage(): @@ -275,31 +275,31 @@ def test_api_create_insert_w_lineage(): # basic ident stuff ident = ss.name - assert len(lca_db.ident_to_name) == 1 - assert ident in lca_db.ident_to_name - assert lca_db.ident_to_name[ident] == ident - assert len(lca_db.ident_to_idx) == 1 - assert lca_db.ident_to_idx[ident] == 0 - assert len(lca_db.hashval_to_idx) == len(ss.minhash) - assert len(lca_db.idx_to_ident) == 1 - assert lca_db.idx_to_ident[0] == ident + assert len(lca_db._ident_to_name) == 1 + assert ident in lca_db._ident_to_name + assert lca_db._ident_to_name[ident] == ident + assert len(lca_db._ident_to_idx) == 1 + assert 
lca_db._ident_to_idx[ident] == 0 + assert len(lca_db._hashval_to_idx) == len(ss.minhash) + assert len(lca_db._idx_to_ident) == 1 + assert lca_db._idx_to_ident[0] == ident # all hash values added set_of_values = set() - for vv in lca_db.hashval_to_idx.values(): + for vv in lca_db._hashval_to_idx.values(): set_of_values.update(vv) assert len(set_of_values) == 1 assert set_of_values == { 0 } # check lineage stuff - assert len(lca_db.idx_to_lid) == 1 - assert lca_db.idx_to_lid[0] == 0 - assert len(lca_db.lid_to_lineage) == 1 - assert lca_db.lid_to_lineage[0] == lineage - assert lca_db.lid_to_idx[0] == { 0 } + assert len(lca_db._idx_to_lid) == 1 + assert lca_db._idx_to_lid[0] == 0 + assert len(lca_db._lid_to_lineage) == 1 + assert lca_db._lid_to_lineage[0] == lineage + assert lca_db._lid_to_idx[0] == { 0 } - assert len(lca_db.lineage_to_lid) == 1 - assert lca_db.lineage_to_lid[lineage] == 0 + assert len(lca_db._lineage_to_lid) == 1 + assert lca_db._lineage_to_lid[lineage] == 0 def test_api_create_insert_w_bad_lineage(): @@ -422,7 +422,7 @@ def test_api_create_insert_two_then_scale(): # & check... combined_mins = set(ss.minhash.hashes.keys()) combined_mins.update(set(ss2.minhash.hashes.keys())) - assert len(lca_db.hashval_to_idx) == len(combined_mins) + assert len(lca_db._hashval_to_idx) == len(combined_mins) def test_api_create_insert_scale_two(): @@ -446,7 +446,7 @@ def test_api_create_insert_scale_two(): # & check... 
combined_mins = set(ss.minhash.hashes.keys()) combined_mins.update(set(ss2.minhash.hashes.keys())) - assert len(lca_db.hashval_to_idx) == len(combined_mins) + assert len(lca_db._hashval_to_idx) == len(combined_mins) def test_load_single_db(): @@ -692,7 +692,7 @@ def test_db_lineage_to_lid(): dbfile = utils.get_test_data('lca/47+63.lca.json') db, ksize, scaled = lca_utils.load_single_database(dbfile) - d = db.lineage_to_lid + d = db._lineage_to_lid items = list(d.items()) items.sort() assert len(items) == 2 @@ -711,7 +711,7 @@ def test_db_lid_to_idx(): dbfile = utils.get_test_data('lca/47+63.lca.json') db, ksize, scaled = lca_utils.load_single_database(dbfile) - d = db.lid_to_idx + d = db._lid_to_idx items = list(d.items()) items.sort() assert len(items) == 2 @@ -724,7 +724,7 @@ def test_db_idx_to_ident(): dbfile = utils.get_test_data('lca/47+63.lca.json') db, ksize, scaled = lca_utils.load_single_database(dbfile) - d = db.idx_to_ident + d = db._idx_to_ident items = list(d.items()) items.sort() assert len(items) == 2 @@ -2060,6 +2060,20 @@ def test_rankinfo_with_min(runtmp): assert not lines +def test_rankinfo_with_min_2(runtmp): + db1 = utils.get_test_data('lca/dir1.lca.json') + db2 = utils.get_test_data('lca/dir2.lca.json') + + cmd = ['lca', 'rankinfo', db1, db2, '--minimum-num', '2'] + runtmp.sourmash(*cmd) + + print(cmd) + print(runtmp.last_result.out) + print(runtmp.last_result.err) + + assert "(no hashvals with lineages found)" in runtmp.last_result.err + + def test_compare_csv(runtmp): a = utils.get_test_data('lca/classify-by-both.csv') b = utils.get_test_data('lca/tara-delmont-SuppTable3.csv') @@ -2359,7 +2373,7 @@ def test_lca_db_protein_command_index(c): db_out = c.output('protein.lca.json') c.run_sourmash('lca', 'index', lineages, db_out, sigfile1, sigfile2, - '-C', '3', '--split-identifiers', '--require-taxonomy', + '-C', '2', '--split-identifiers', '--require-taxonomy', '--scaled', '100', '-k', '19', '--protein') x = 
sourmash.lca.lca_db.load_single_database(db_out) @@ -2468,7 +2482,7 @@ def test_lca_db_hp_command_index(c): db_out = c.output('hp.lca.json') c.run_sourmash('lca', 'index', lineages, db_out, sigfile1, sigfile2, - '-C', '3', '--split-identifiers', '--require-taxonomy', + '-C', '2', '--split-identifiers', '--require-taxonomy', '--scaled', '100', '-k', '19', '--hp') x = sourmash.lca.lca_db.load_single_database(db_out) @@ -2577,7 +2591,7 @@ def test_lca_db_dayhoff_command_index(c): db_out = c.output('dayhoff.lca.json') c.run_sourmash('lca', 'index', lineages, db_out, sigfile1, sigfile2, - '-C', '3', '--split-identifiers', '--require-taxonomy', + '-C', '2', '--split-identifiers', '--require-taxonomy', '--scaled', '100', '-k', '19', '--dayhoff') x = sourmash.lca.lca_db.load_single_database(db_out) @@ -2713,4 +2727,6 @@ def _intersect(x, y): print(sr) assert len(sr) == 2 assert sr[0].signature == ss_a + assert sr[0].score == 1.0 assert sr[1].signature == ss_c + assert sr[1].score == 0.2 diff --git a/tests/test_lca_db_protocol.py b/tests/test_lca_db_protocol.py new file mode 100644 index 0000000000..daca0f2f62 --- /dev/null +++ b/tests/test_lca_db_protocol.py @@ -0,0 +1,108 @@ +""" +Test the behavior of LCA databases. New LCA database classes should support +all of this functionality. 
+""" +import pytest +import sourmash_tst_utils as utils + +import sourmash +from sourmash.tax.tax_utils import MultiLineageDB +from sourmash.lca.lca_db import (LCA_Database, load_single_database) + + +def build_inmem_lca_db(runtmp): + # test in-memory LCA_Database + sigfile1 = utils.get_test_data('prot/protein/GCA_001593925.1_ASM159392v1_protein.faa.gz.sig') + sigfile2 = utils.get_test_data('prot/protein/GCA_001593935.1_ASM159393v1_protein.faa.gz.sig') + + ss1 = sourmash.load_one_signature(sigfile1) + ss2 = sourmash.load_one_signature(sigfile2) + + lineages_file = utils.get_test_data('prot/gtdb-subset-lineages.csv') + lineages = MultiLineageDB.load([lineages_file]) + + db = LCA_Database(ksize=19, scaled=100, moltype='protein') + + ident1 = ss1.name.split(' ')[0].split('.')[0] + assert lineages[ident1] + db.insert(ss1, ident=ident1, lineage=lineages[ident1]) + ident2 = ss2.name.split(' ')[0].split('.')[0] + assert lineages[ident2] + db.insert(ss2, ident=ident2, lineage=lineages[ident2]) + + return db + + +def build_json_lca_db(runtmp): + # test saved/loaded JSON database + db = build_inmem_lca_db(runtmp) + db_out = runtmp.output('protein.lca.json') + + db.save(db_out) + + x = load_single_database(db_out) + db_load = x[0] + + return db_load + + +@pytest.fixture(params=[build_inmem_lca_db, + build_json_lca_db]) +def lca_db_obj(request, runtmp): + build_fn = request.param + + return build_fn(runtmp) + + +def test_get_lineage_assignments(lca_db_obj): + # test get_lineage_assignments for a specific hash + lineages = lca_db_obj.get_lineage_assignments(178936042868009693) + + assert len(lineages) == 1 + lineage = lineages[0] + + x = [] + for tup in lineage: + if tup[0] != 'strain' or tup[1]: # ignore empty strain + x.append((tup[0], tup[1])) + + assert x == [('superkingdom', 'd__Archaea'), + ('phylum', 'p__Crenarchaeota'), + ('class', 'c__Bathyarchaeia'), + ('order', 'o__B26-1'), + ('family', 'f__B26-1'), + ('genus', 'g__B26-1'), + ('species', 's__B26-1 sp001593925'),] + + 
+def test_hashvals(lca_db_obj): + # test getting individual hashvals + hashvals = set(lca_db_obj.hashvals) + assert 178936042868009693 in hashvals + + +def test_get_identifiers_for_hashval(lca_db_obj): + # test getting identifiers belonging to individual hashvals + idents = lca_db_obj.get_identifiers_for_hashval(178936042868009693) + idents = list(idents) + assert len(idents) == 1 + + ident = idents[0] + assert ident == 'GCA_001593925' + + +def test_get_identifiers_for_hashval_2(lca_db_obj): + # test systematic hashval => identifiers + all_idents = set() + + for hashval in lca_db_obj.hashvals: + idents = lca_db_obj.get_identifiers_for_hashval(hashval) + #idents = list(idents) + all_idents.update(idents) + + all_idents = list(all_idents) + print(all_idents) + assert len(all_idents) == 2 + + assert 'GCA_001593925' in all_idents + assert 'GCA_001593935' in all_idents diff --git a/tests/test_manifest_protocol.py b/tests/test_manifest_protocol.py new file mode 100644 index 0000000000..bbfa7691a0 --- /dev/null +++ b/tests/test_manifest_protocol.py @@ -0,0 +1,166 @@ +""" +Tests for the 'CollectionManifest' class and protocol. All subclasses +of BaseCollectionManifest should support this functionality. +""" + +import pytest +import sourmash_tst_utils as utils + +import sourmash +from sourmash.manifest import BaseCollectionManifest, CollectionManifest + + +def build_simple_manifest(runtmp): + # load and return the manifest from prot/all.zip + filename = utils.get_test_data('prot/all.zip') + idx = sourmash.load_file_as_index(filename) + mf = idx.manifest + assert len(mf) == 8 + return mf + + +def save_load_manifest(runtmp): + # save/load the manifest from a CSV. 
+ mf = build_simple_manifest(runtmp) + + mf_csv = runtmp.output('mf.csv') + mf.write_to_filename(mf_csv) + + load_mf = CollectionManifest.load_from_filename(mf_csv) + return load_mf + + +@pytest.fixture(params=[build_simple_manifest, + save_load_manifest]) +def manifest_obj(request, runtmp): + build_fn = request.param + + return build_fn(runtmp) + + +### +### generic CollectionManifest tests go here +### + +def test_manifest_len(manifest_obj): + # check that 'len' works + assert len(manifest_obj) == 8 + + +def test_manifest_rows(manifest_obj): + # check that '.rows' property works + rows = list(manifest_obj.rows) + assert len(rows) == 8 + + required_keys = set(BaseCollectionManifest.required_keys) + for row in rows: + kk = set(row.keys()) + assert required_keys.issubset(kk) + + +def test_manifest_bool(manifest_obj): + # check that 'bool' works + assert bool(manifest_obj) + + +def test_make_manifest_row(manifest_obj): + # build a manifest row from a signature + sig47 = utils.get_test_data('47.fa.sig') + ss = sourmash.load_one_signature(sig47) + + row = manifest_obj.make_manifest_row(ss, 'foo', include_signature=False) + assert not 'signature' in row + assert row['internal_location'] == 'foo' + + assert row['md5'] == ss.md5sum() + assert row['md5short'] == ss.md5sum()[:8] + assert row['ksize'] == 31 + assert row['moltype'] == 'DNA' + assert row['num'] == 0 + assert row['scaled'] == 1000 + assert row['n_hashes'] == len(ss.minhash) + assert not row['with_abundance'] + assert row['name'] == ss.name + assert row['filename'] == ss.filename + + +def test_manifest_create_manifest(manifest_obj): + # test the 'create_manifest' method + sig47 = utils.get_test_data('47.fa.sig') + ss = sourmash.load_one_signature(sig47) + + def yield_sigs(): + yield ss, 'fiz' + + new_mf = manifest_obj.create_manifest(yield_sigs(), + include_signature=False) + assert len(new_mf) == 1 + new_row = list(new_mf.rows)[0] + + row = manifest_obj.make_manifest_row(ss, 'fiz', include_signature=False) + 
required_keys = BaseCollectionManifest.required_keys + for k in required_keys: + assert new_row[k] == row[k], k + + +def test_manifest_select_to_manifest(manifest_obj): + # do some light testing of 'select_to_manifest' + new_mf = manifest_obj.select_to_manifest(moltype='DNA') + assert len(new_mf) == 2 + + +def test_manifest_locations(manifest_obj): + # check the 'locations' method + locs = set(['dayhoff/GCA_001593925.1_ASM159392v1_protein.faa.gz.sig', + 'dayhoff/GCA_001593935.1_ASM159393v1_protein.faa.gz.sig', + 'hp/GCA_001593925.1_ASM159392v1_protein.faa.gz.sig', + 'hp/GCA_001593935.1_ASM159393v1_protein.faa.gz.sig', + 'protein/GCA_001593925.1_ASM159392v1_protein.faa.gz.sig', + 'protein/GCA_001593935.1_ASM159393v1_protein.faa.gz.sig', + 'dna-sig.noext', + 'dna-sig.sig.gz'] + ) + assert set(manifest_obj.locations()) == locs + + +def test_manifest___contains__(manifest_obj): + # check the 'in' operator + sigfile = utils.get_test_data('prot/dayhoff/GCA_001593925.1_ASM159392v1_protein.faa.gz.sig') + ss = sourmash.load_one_signature(sigfile) + + assert ss in manifest_obj + + sigfile2 = utils.get_test_data('2.fa.sig') + ss2 = sourmash.load_one_signature(sigfile2, ksize=31) + assert ss2 not in manifest_obj + + +def test_manifest_to_picklist(manifest_obj): + # test 'to_picklist' + picklist = manifest_obj.to_picklist() + mf = manifest_obj.select_to_manifest(picklist=picklist) + + assert mf == manifest_obj + + +def test_manifest_filter_rows(manifest_obj): + # test filter_rows + filter_fn = lambda x: 'OS223' in x['name'] + + mf = manifest_obj.filter_rows(filter_fn) + + assert len(mf) == 1 + row = list(mf.rows)[0] + assert row['name'] == 'NC_011663.1 Shewanella baltica OS223, complete genome' + + +def test_manifest_filter_cols(manifest_obj): + # test filter_on_columns + col_filter_fn = lambda x: 'OS223' in x[0] + + mf = manifest_obj.filter_on_columns(col_filter_fn, ['name']) + + assert len(mf) == 1 + row = list(mf.rows)[0] + assert row['name'] == 'NC_011663.1 Shewanella baltica 
OS223, complete genome' diff --git a/tests/test_sbt.py b/tests/test_sbt.py index f31aa8c77c..9d9ba7273a 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -979,7 +979,9 @@ def _intersect(x, y): print(sr) assert len(sr) == 2 assert sr[0].signature == ss_a + assert sr[0].score == 1.0 assert sr[1].signature == ss_c + assert sr[1].score == 0.2 def test_sbt_protein_command_index(runtmp):