diff --git a/src/sourmash/command_compute.py b/src/sourmash/command_compute.py index aac66def1..dbb3c42ad 100644 --- a/src/sourmash/command_compute.py +++ b/src/sourmash/command_compute.py @@ -13,9 +13,15 @@ from .utils import RustObject from ._lowlevel import ffi, lib -DEFAULT_COMPUTE_K = "21,31,51" -DEFAULT_MMHASH_SEED = 42 -DEFAULT_LINE_COUNT = 1500 + +from .command_sketch import ( + _compute_individual, + _compute_merged, + ComputeParameters, + add_seq, + set_sig_name, + DEFAULT_MMHASH_SEED, +) def compute(args): diff --git a/src/sourmash/command_sketch.py b/src/sourmash/command_sketch.py index 508cac7c0..e98212f8c 100644 --- a/src/sourmash/command_sketch.py +++ b/src/sourmash/command_sketch.py @@ -12,18 +12,14 @@ import sourmash from .signature import SourmashSignature from .logging import notify, error, set_quiet, print_results -from .command_compute import ( - _compute_individual, - _compute_merged, - ComputeParameters, - add_seq, - set_sig_name, - DEFAULT_MMHASH_SEED, -) from sourmash import sourmash_args from sourmash.sourmash_args import check_scaled_bounds, check_num_bounds from sourmash.sig.__main__ import _summarize_manifest, _SketchInfo from sourmash.manifest import CollectionManifest +from .utils import RustObject +from ._lowlevel import ffi, lib + +DEFAULT_MMHASH_SEED = 42 DEFAULTS = dict( dna="k=31,scaled=1000,noabund", @@ -637,3 +633,452 @@ def fromfile(args): notify( f"** {total_sigs} total requested; output {total_sigs - skipped_sigs}, skipped {skipped_sigs}" ) + + +class _signatures_for_compute_factory: + "Build signatures on demand, based on args input to 'compute'." + + def __init__(self, args): + self.args = args + + def __call__(self): + args = self.args + params = ComputeParameters( + ksizes=args.ksizes, + seed=args.seed, + protein=args.protein, + dayhoff=args.dayhoff, + hp=args.hp, + dna=args.dna, + num_hashes=args.num_hashes, + track_abundance=args.track_abundance, + scaled=args.scaled, + ) + sig = SourmashSignature.from_params(params) + return [sig] + + +def _compute_individual(args, signatures_factory): + # this is where output signatures will go. + save_sigs = None + + # track: is this the first file? in cases where we have empty inputs, + # we don't want to open any outputs. + first_file_for_output = True + + # if args.output is set, we are aggregating all output to a single file. + # do not open a new output file for each input. + open_output_each_time = True + if args.output: + open_output_each_time = False + + for filename in args.filenames: + if open_output_each_time: + # for each input file, construct output filename + sigfile = os.path.basename(filename) + ".sig" + if args.output_dir: + sigfile = os.path.join(args.output_dir, sigfile) + + # does it already exist? skip if so. + if os.path.exists(sigfile) and not args.force: + notify("skipping {} - already done", filename) + continue # go on to next file. + + # nope? ok, let's save to it. + assert not save_sigs + save_sigs = sourmash_args.SaveSignaturesToLocation(sigfile) + + # + # calculate signatures! + # + + # now, set up to iterate over sequences. + with screed.open(filename) as screed_iter: + if not screed_iter: + notify(f"no sequences found in '{filename}'?!") + continue + + # open output for signatures + if open_output_each_time: + save_sigs.open() + # or... is this the first time to write something to args.output? + elif first_file_for_output: + save_sigs = sourmash_args.SaveSignaturesToLocation(args.output) + save_sigs.open() + first_file_for_output = False + + # make a new signature for each sequence? + if args.singleton: + n_calculated = 0 + for n, record in enumerate(screed_iter): + sigs = signatures_factory() + try: + add_seq( + sigs, + record.sequence, + args.input_is_protein, + args.check_sequence, + ) + except ValueError as exc: + error(f"ERROR when reading from '{filename}' - ") + error(str(exc)) + sys.exit(-1) + + n_calculated += len(sigs) + set_sig_name(sigs, filename, name=record.name) + save_sigs_to_location(sigs, save_sigs) + + notify( + "calculated {} signatures for {} sequences in {}", + n_calculated, + n + 1, + filename, + ) + + # nope; make a single sig for the whole file + else: + sigs = signatures_factory() + + # consume & calculate signatures + notify(f"... reading sequences from {filename}") + name = None + for n, record in enumerate(screed_iter): + if n % 10000 == 0: + if n: + notify("\r...{} {}", filename, n, end="") + elif args.name_from_first: + name = record.name + + try: + add_seq( + sigs, + record.sequence, + args.input_is_protein, + args.check_sequence, + ) + except ValueError as exc: + error(f"ERROR when reading from '{filename}' - ") + error(str(exc)) + sys.exit(-1) + + notify("...{} {} sequences", filename, n, end="") + + set_sig_name(sigs, filename, name) + save_sigs_to_location(sigs, save_sigs) + + notify( + f"calculated {len(sigs)} signatures for {n+1} sequences in {filename}" + ) + + # if not args.output, close output for every input filename. + if open_output_each_time: + save_sigs.close() + notify( + f"saved {len(save_sigs)} signature(s) to '{save_sigs.location}'. Note: signature license is CC0." + ) + save_sigs = None + + # if --output-dir specified, all collected signatures => args.output, + # and we need to close here. + if args.output and save_sigs is not None: + save_sigs.close() + notify( + f"saved {len(save_sigs)} signature(s) to '{save_sigs.location}'. Note: signature license is CC0." + ) + + +def _compute_merged(args, signatures_factory): + # make a signature for the whole file + sigs = signatures_factory() + + total_seq = 0 + for filename in args.filenames: + # consume & calculate signatures + notify("... reading sequences from {}", filename) + + n = None + with screed.open(filename) as f: + for n, record in enumerate(f): + if n % 10000 == 0 and n: + notify("\r... {} {}", filename, n, end="") + + add_seq( + sigs, record.sequence, args.input_is_protein, args.check_sequence + ) + if n is not None: + notify("... {} {} sequences", filename, n + 1) + total_seq += n + 1 + else: + notify(f"no sequences found in '{filename}'?!") + + if total_seq: + set_sig_name(sigs, filename, name=args.merge) + notify( + "calculated 1 signature for {} sequences taken from {} files", + total_seq, + len(args.filenames), + ) + + # at end, save! + save_siglist(sigs, args.output) + + +def add_seq(sigs, seq, input_is_protein, check_sequence): + for sig in sigs: + if input_is_protein: + sig.add_protein(seq) + else: + sig.add_sequence(seq, not check_sequence) + + +def set_sig_name(sigs, filename, name=None): + if filename == "-": # if stdin, set filename to empty. + filename = "" + for sig in sigs: + if name is not None: + sig._name = name + + sig.filename = filename + + +def save_siglist(siglist, sigfile_name): + "Save multiple signatures to a filename." + + # save! + with sourmash_args.SaveSignaturesToLocation(sigfile_name) as save_sig: + for ss in siglist: + save_sig.add(ss) + + notify(f"saved {len(save_sig)} signature(s) to '{save_sig.location}'") + + +def save_sigs_to_location(siglist, save_sig): + "Save multiple signatures to an already-open location." + import sourmash + + for ss in siglist: + save_sig.add(ss) + + +class ComputeParameters(RustObject): + __dealloc_func__ = lib.computeparams_free + + def __init__( + self, + *, + ksizes=(21, 31, 51), + seed=42, + protein=False, + dayhoff=False, + hp=False, + dna=True, + num_hashes=500, + track_abundance=False, + scaled=0, + ): + self._objptr = lib.computeparams_new() + + self.seed = seed + self.ksizes = ksizes + self.protein = protein + self.dayhoff = dayhoff + self.hp = hp + self.dna = dna + self.num_hashes = num_hashes + self.track_abundance = track_abundance + self.scaled = scaled + + @classmethod + def from_manifest_row(cls, row): + "convert a CollectionManifest row into a ComputeParameters object" + is_dna = is_protein = is_dayhoff = is_hp = False + if row["moltype"] == "DNA": + is_dna = True + elif row["moltype"] == "protein": + is_protein = True + elif row["moltype"] == "hp": + is_hp = True + elif row["moltype"] == "dayhoff": + is_dayhoff = True + else: + assert 0 + + if is_dna: + ksize = row["ksize"] + else: + ksize = row["ksize"] * 3 + + p = cls( + ksizes=[ksize], + seed=DEFAULT_MMHASH_SEED, + protein=is_protein, + dayhoff=is_dayhoff, + hp=is_hp, + dna=is_dna, + num_hashes=row["num"], + track_abundance=row["with_abundance"], + scaled=row["scaled"], + ) + + return p + + def to_param_str(self): + "Convert object to equivalent params str." + pi = [] + + if self.dna: + pi.append("dna") + elif self.protein: + pi.append("protein") + elif self.hp: + pi.append("hp") + elif self.dayhoff: + pi.append("dayhoff") + else: + assert 0 # must be one of the previous + + if self.dna: + kstr = [f"k={k}" for k in self.ksizes] + else: + # for protein, divide ksize by three. + kstr = [f"k={k//3}" for k in self.ksizes] + assert kstr + pi.extend(kstr) + + if self.num_hashes != 0: + pi.append(f"num={self.num_hashes}") + elif self.scaled != 0: + pi.append(f"scaled={self.scaled}") + else: + assert 0 + + if self.track_abundance: + pi.append("abund") + # noabund is default + + if self.seed != DEFAULT_MMHASH_SEED: + pi.append(f"seed={self.seed}") + # self.seed + + return ",".join(pi) + + def __repr__(self): + return f"ComputeParameters(ksizes={self.ksizes}, seed={self.seed}, protein={self.protein}, dayhoff={self.dayhoff}, hp={self.hp}, dna={self.dna}, num_hashes={self.num_hashes}, track_abundance={self.track_abundance}, scaled={self.scaled})" + + def __eq__(self, other): + return ( + self.ksizes == other.ksizes + and self.seed == other.seed + and self.protein == other.protein + and self.dayhoff == other.dayhoff + and self.hp == other.hp + and self.dna == other.dna + and self.num_hashes == other.num_hashes + and self.track_abundance == other.track_abundance + and self.scaled == other.scaled + ) + + @staticmethod + def from_args(args): + ptr = lib.computeparams_new() + ret = ComputeParameters._from_objptr(ptr) + + for arg, value in vars(args).items(): + try: + getattr(type(ret), arg).fset(ret, value) + except AttributeError: + pass + + return ret + + @property + def seed(self): + return self._methodcall(lib.computeparams_seed) + + @seed.setter + def seed(self, v): + return self._methodcall(lib.computeparams_set_seed, v) + + @property + def ksizes(self): + size = ffi.new("uintptr_t *") + ksizes_ptr = self._methodcall(lib.computeparams_ksizes, size) + size = size[0] + ksizes = ffi.unpack(ksizes_ptr, size) + lib.computeparams_ksizes_free(ksizes_ptr, size) + return ksizes + + @ksizes.setter + def ksizes(self, v): + return self._methodcall(lib.computeparams_set_ksizes, list(v), len(v)) + + @property + def protein(self): + return self._methodcall(lib.computeparams_protein) + + @protein.setter + def protein(self, v): + return self._methodcall(lib.computeparams_set_protein, v) + + @property + def dayhoff(self): + return self._methodcall(lib.computeparams_dayhoff) + + @dayhoff.setter + def dayhoff(self, v): + return self._methodcall(lib.computeparams_set_dayhoff, v) + + @property + def hp(self): + return self._methodcall(lib.computeparams_hp) + + @hp.setter + def hp(self, v): + return self._methodcall(lib.computeparams_set_hp, v) + + @property + def dna(self): + return self._methodcall(lib.computeparams_dna) + + @dna.setter + def dna(self, v): + return self._methodcall(lib.computeparams_set_dna, v) + + @property + def moltype(self): + if self.dna: + moltype = "DNA" + elif self.protein: + moltype = "protein" + elif self.hp: + moltype = "hp" + elif self.dayhoff: + moltype = "dayhoff" + else: + assert 0 + + return moltype + + @property + def num_hashes(self): + return self._methodcall(lib.computeparams_num_hashes) + + @num_hashes.setter + def num_hashes(self, v): + return self._methodcall(lib.computeparams_set_num_hashes, v) + + @property + def track_abundance(self): + return self._methodcall(lib.computeparams_track_abundance) + + @track_abundance.setter + def track_abundance(self, v): + return self._methodcall(lib.computeparams_set_track_abundance, v) + + @property + def scaled(self): + return self._methodcall(lib.computeparams_scaled) + + @scaled.setter + def scaled(self, v): + return self._methodcall(lib.computeparams_set_scaled, int(v)) diff --git a/tests/test_cmd_signature.py b/tests/test_cmd_signature.py index 9f14b6df5..8dfe8dc74 100644 --- a/tests/test_cmd_signature.py +++ b/tests/test_cmd_signature.py @@ -3355,16 +3355,35 @@ def test_sig_describe_dayhoff(c): ) -@utils.in_tempdir -def test_sig_describe_1_hp(c): +def test_sig_describe_1_hp(runtmp): + c = runtmp + # get basic info on a signature testdata = utils.get_test_data("short.fa") - c.run_sourmash( - "compute", "-k", "21,30", "--dayhoff", "--hp", "--protein", "--dna", testdata + + # run four separate commands to make 4 different sets of sigs... + c.sourmash("sketch", "dna", "-p", "k=21,k=30,num=500", "-o", "out.zip", testdata) + c.sourmash( + "sketch", "translate", "-p", "k=7,k=10,num=500", "-o", "out.zip", testdata + ) + c.sourmash( + "sketch", "translate", "-p", "k=7,k=10,num=500,hp", "-o", "out.zip", testdata + ) + c.sourmash( + "sketch", + "translate", + "-p", + "k=7,k=10,num=500,dayhoff", + "-o", + "out.zip", + testdata, ) - # stdout should be new signature - computed_sig = os.path.join(c.location, "short.fa.sig") - c.run_sourmash("sig", "describe", computed_sig) + + # then combine into one .sig file + c.sourmash("sig", "cat", "out.zip", "-o", "short.fa.sig") + + # & run sig describe + c.run_sourmash("sig", "describe", "short.fa.sig") out = c.last_result.out print(c.last_result.out) @@ -3444,7 +3463,6 @@ def test_sig_describe_1_hp(c): signature license: CC0 --- -signature filename: short.fa.sig signature: ** no name ** source file: short.fa md5: 71f7c111c01785e5f38efad45b00a0e1 diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index fc083a21e..23647e517 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -1791,26 +1791,48 @@ def test_compare_deduce_molecule(runtmp): def test_compare_choose_molecule_dna(runtmp): - # choose molecule type + # choose molecule type with --dna, ignoring protein testdata1 = utils.get_test_data("short.fa") testdata2 = utils.get_test_data("short2.fa") - runtmp.sourmash("compute", "-k", "30", "--dna", "--protein", testdata1, testdata2) - - runtmp.sourmash("compare", "--dna", "short.fa.sig", "short2.fa.sig") + runtmp.sourmash( + "sketch", "dna", "-p", "k=30,num=500", testdata1, testdata2, "-o", "sigs.zip" + ) + runtmp.sourmash( + "sketch", + "translate", + "-p", + "k=10,num=500", + testdata1, + testdata2, + "-o", + "sigs.zip", + ) + runtmp.sourmash("compare", "--dna", "sigs.zip") print(runtmp.last_result.status, runtmp.last_result.out, runtmp.last_result.err) assert "min similarity in matrix: 0.938" in runtmp.last_result.out def test_compare_choose_molecule_protein(runtmp): - # choose molecule type + # choose molecule type with --protein, ignoring DNA testdata1 = utils.get_test_data("short.fa") testdata2 = utils.get_test_data("short2.fa") - runtmp.sourmash("compute", "-k", "30", "--dna", "--protein", testdata1, testdata2) - - runtmp.sourmash("compare", "--protein", "short.fa.sig", "short2.fa.sig") + runtmp.sourmash( + "sketch", "dna", "-p", "k=30,num=500", testdata1, testdata2, "-o", "sigs.zip" + ) + runtmp.sourmash( + "sketch", + "translate", + "-p", + "k=10,num=500", + testdata1, + testdata2, + "-o", + "sigs.zip", + ) + runtmp.sourmash("compare", "--protein", "sigs.zip") print(runtmp.last_result.status, runtmp.last_result.out, runtmp.last_result.err) assert "min similarity in matrix: 0.91" in runtmp.last_result.out diff --git a/tests/test_sourmash_sketch.py b/tests/test_sourmash_sketch.py index 87460dcbc..98448e4d6 100644 --- a/tests/test_sourmash_sketch.py +++ b/tests/test_sourmash_sketch.py @@ -15,8 +15,7 @@ from sourmash import MinHash from sourmash.sbt import SBT, Node from sourmash.sbtmh import SigLeaf, load_sbt_index -from sourmash.command_compute import ComputeParameters -from sourmash.cli.compute import subparser +from sourmash.command_sketch import ComputeParameters from sourmash.cli import SourmashParser from sourmash import manifest