diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py index 34eefaf8b8..02a04969fb 100644 --- a/src/sourmash/commands.py +++ b/src/sourmash/commands.py @@ -1106,7 +1106,8 @@ def prefetch(args): # iterate over signatures in db one at a time, for each db; # find those with sufficient overlap - noident_mh = copy.copy(query_mh) + noident_mh = query_mh.to_mutable() + did_a_search = False # track whether we did _any_ search at all! for dbfilename in args.databases: notify(f"loading signatures from '{dbfilename}'") @@ -1164,7 +1165,7 @@ def prefetch(args): notify(f"saved {matches_out.count} matches to CSV file '{args.output}'") csvout_fp.close() - matched_query_mh = copy.copy(query_mh) + matched_query_mh = query_mh.to_mutable() matched_query_mh.remove_many(noident_mh.hashes) notify(f"of {len(query_mh)} distinct query hashes, {len(matched_query_mh)} were found in matches above threshold.") notify(f"a total of {len(noident_mh)} query hashes remain unmatched.") diff --git a/src/sourmash/minhash.py b/src/sourmash/minhash.py index 8d45d7118e..05b8b7361e 100644 --- a/src/sourmash/minhash.py +++ b/src/sourmash/minhash.py @@ -588,7 +588,7 @@ def __add__(self, other): if self.num != other.num: raise TypeError(f"incompatible num values: self={self.num} other={other.num}") - new_obj = self.__copy__() + new_obj = self.to_mutable() new_obj += other return new_obj @@ -645,3 +645,107 @@ def moltype(self): # TODO: test in minhash tests return 'hp' else: return 'DNA' + + def to_mutable(self): + "Return a copy of this MinHash that can be changed." + return self.__copy__() + + def to_frozen(self): + "Return a frozen copy of this MinHash that cannot be changed." + new_mh = self.__copy__() + new_mh.__class__ = FrozenMinHash + return new_mh + + +class FrozenMinHash(MinHash): + def add_sequence(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def add_kmer(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def add_many(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def remove_many(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def add_hash(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def add_hash_with_abundance(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def clear(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def remove_many(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def set_abundances(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def add_protein(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def downsample(self, *, num=None, scaled=None): + if scaled and self.scaled == scaled: + return self + if num and self.num == num: + return self + + return MinHash.downsample(self, num=num, scaled=scaled).to_frozen() + + def flatten(self): + if not self.track_abundance: + return self + return MinHash.flatten(self).to_frozen() + + def __iadd__(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def merge(self, *args, **kwargs): + raise TypeError('FrozenMinHash does not support modification') + + def to_mutable(self): + "Return a copy of this MinHash that can be changed." + mut = MinHash.__new__(MinHash) + state_tup = self.__getstate__() + + # is protein/hp/dayhoff? + if state_tup[2] or state_tup[3] or state_tup[4]: + state_tup = list(state_tup) + # adjust ksize. + state_tup[1] = state_tup[1] * 3 + mut.__setstate__(state_tup) + return mut + + def to_frozen(self): + "Return a frozen copy of this MinHash that cannot be changed." + return self + + def __setstate__(self, tup): + "support pickling via __getstate__/__setstate__" + (n, ksize, is_protein, dayhoff, hp, mins, _, track_abundance, + max_hash, seed) = tup + + self.__del__() + + hash_function = ( + lib.HASH_FUNCTIONS_MURMUR64_DAYHOFF if dayhoff else + lib.HASH_FUNCTIONS_MURMUR64_HP if hp else + lib.HASH_FUNCTIONS_MURMUR64_PROTEIN if is_protein else + lib.HASH_FUNCTIONS_MURMUR64_DNA + ) + + scaled = _get_scaled_for_max_hash(max_hash) + self._objptr = lib.kmerminhash_new( + scaled, ksize, hash_function, seed, track_abundance, n + ) + if track_abundance: + MinHash.set_abundances(self, mins) + else: + MinHash.add_many(self, mins) + + def __copy__(self): + return self diff --git a/src/sourmash/search.py b/src/sourmash/search.py index d20ed5f249..93d77920ce 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -354,6 +354,7 @@ def gather_databases(query, counters, threshold_bp, ignore_abundance): # construct a new query, subtracting hashes found in previous one. new_query_mh = query.minhash.downsample(scaled=cmp_scaled) + new_query_mh = new_query_mh.to_mutable() new_query_mh.remove_many(set(found_mh.hashes)) new_query = SourmashSignature(new_query_mh) diff --git a/src/sourmash/signature.py b/src/sourmash/signature.py index e382e58311..b1915d38cf 100644 --- a/src/sourmash/signature.py +++ b/src/sourmash/signature.py @@ -9,7 +9,7 @@ from .logging import error from . import MinHash -from .minhash import to_bytes +from .minhash import to_bytes, FrozenMinHash from ._lowlevel import ffi, lib from .utils import RustObject, rustcall, decode_str @@ -42,7 +42,7 @@ def __init__(self, minhash, name="", filename=""): @property def minhash(self): - return MinHash._from_objptr( + return FrozenMinHash._from_objptr( self._methodcall(lib.signature_first_mh) ) diff --git a/tests/test_index.py b/tests/test_index.py index 7701337daa..5e30c7c585 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1326,6 +1326,7 @@ def is_found(ss, xx): def _consume_all(query_mh, counter, threshold_bp=0): results = [] + query_mh = query_mh.to_mutable() last_intersect_size = None while 1: @@ -1891,7 +1892,7 @@ def test_counter_gather_3_test_consume(): ## round 1 - cur_query = copy.copy(query_ss.minhash) + cur_query = query_ss.minhash.to_mutable() (sr, intersect_mh) = counter.peek(cur_query) assert sr.signature == match_ss_1 assert len(intersect_mh) == 10 diff --git a/tests/test__minhash.py b/tests/test_minhash.py similarity index 98% rename from tests/test__minhash.py rename to tests/test_minhash.py index dfb172687f..509731718f 100644 --- a/tests/test__minhash.py +++ b/tests/test_minhash.py @@ -42,6 +42,7 @@ import sourmash from sourmash.minhash import ( MinHash, + FrozenMinHash, hash_murmur, _get_scaled_for_max_hash, _get_max_hash_for_scaled, @@ -1908,3 +1909,39 @@ def test_max_containment_equal(): assert mh2.contained_by(mh1) == 1 assert mh1.max_containment(mh2) == 1 assert mh2.max_containment(mh1) == 1 + + +def test_frozen_and_mutable_1(track_abundance): + # mutable minhashes -> mutable minhashes creates new copy + mh1 = MinHash(0, 21, scaled=1, track_abundance=track_abundance) + mh2 = mh1.to_mutable() + + mh1.add_hash(10) + assert 10 not in mh2.hashes + + +def test_frozen_and_mutable_2(track_abundance): + # check that mutable -> frozen are separate + mh1 = MinHash(0, 21, scaled=1, track_abundance=track_abundance) + mh1.add_hash(10) + + mh2 = mh1.to_frozen() + assert 10 in mh2.hashes + mh1.add_hash(11) + assert 11 not in mh2.hashes + + +def test_frozen_and_mutable_3(track_abundance): + # check that mutable -> frozen -> mutable are all separate from each other + mh1 = MinHash(0, 21, scaled=1, track_abundance=track_abundance) + mh1.add_hash(10) + + mh2 = mh1.to_frozen() + assert 10 in mh2.hashes + mh1.add_hash(11) + assert 11 not in mh2.hashes + + mh3 = mh2.to_mutable() + mh3.add_hash(12) + assert 12 not in mh2.hashes + assert 12 not in mh1.hashes diff --git a/tests/test_prefetch.py b/tests/test_prefetch.py index d56b928a20..da37559d2b 100644 --- a/tests/test_prefetch.py +++ b/tests/test_prefetch.py @@ -295,7 +295,7 @@ def test_prefetch_nomatch_hashes(runtmp, linear_gather): ss47 = sourmash.load_one_signature(sig47, ksize=31) ss63 = sourmash.load_one_signature(sig63, ksize=31) - remain = ss47.minhash + remain = ss47.minhash.to_mutable() remain.remove_many(ss63.minhash.hashes) ss = sourmash.load_one_signature(nomatch_out) diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index 0d47987ede..82c73eb51e 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -3111,7 +3111,7 @@ def test_gather_f_match_orig(runtmp, linear_gather, prefetch_gather): print(runtmp.last_result.err) combined_sig = sourmash.load_one_signature(testdata_combined, ksize=21) - remaining_mh = copy.copy(combined_sig.minhash) + remaining_mh = combined_sig.minhash.to_mutable() def approx_equal(a, b, n=5): return round(a, n) == round(b, n)