Skip to content

Commit

Permalink
Merge branch 'latest' into fix/search_type_error
Browse files Browse the repository at this point in the history
  • Loading branch information
ctb authored Apr 14, 2022
2 parents c4e2020 + f0726c3 commit 1797943
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 68 deletions.
30 changes: 0 additions & 30 deletions .github/workflows/khmer.yml

This file was deleted.

1 change: 0 additions & 1 deletion asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"branches": ["latest"],
"dvcs": "git",
"environment_type": "virtualenv",
"pythons": ["3.10"],
"env_dir": ".asv/env",
"results_dir": ".asv/results",
"html_dir": ".asv/html",
Expand Down
72 changes: 71 additions & 1 deletion src/core/benches/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,75 @@ fn add_sequence(c: &mut Criterion) {
});
}

criterion_group!(compute, add_sequence);
fn add_sequence_protein(c: &mut Criterion) {
let mut cp = ComputeParameters::default();
cp.set_protein(true);
cp.set_dna(false);
cp.set_scaled(200);
cp.set_ksizes(vec![30]);
let template_sig = Signature::from_params(&cp);

let mut data: Vec<u8> = vec![];
let (mut f, _) = niffler::from_path("../../tests/test-data/genome-s10.fa.gz").unwrap();
let _ = f.read_to_end(&mut data);

let data = data.repeat(10);

let data_upper = data.to_ascii_uppercase();
let data_lower = data.to_ascii_lowercase();
let data_errors: Vec<u8> = data
.iter()
.enumerate()
.map(|(i, x)| if i % 89 == 1 { b'N' } else { *x })
.collect();

let mut group = c.benchmark_group("add_sequence_protein");
group.sample_size(10);

group.bench_function("valid", |b| {
b.iter(|| {
let fasta_data = Cursor::new(data_upper.clone());
let mut sig = template_sig.clone();
let mut parser = parse_fastx_reader(fasta_data).unwrap();
while let Some(rec) = parser.next() {
sig.add_protein(&rec.unwrap().seq()).unwrap();
}
});
});

group.bench_function("lowercase", |b| {
b.iter(|| {
let fasta_data = Cursor::new(data_lower.clone());
let mut sig = template_sig.clone();
let mut parser = parse_fastx_reader(fasta_data).unwrap();
while let Some(rec) = parser.next() {
sig.add_protein(&rec.unwrap().seq()).unwrap();
}
});
});

group.bench_function("invalid kmers", |b| {
b.iter(|| {
let fasta_data = Cursor::new(data_errors.clone());
let mut sig = template_sig.clone();
let mut parser = parse_fastx_reader(fasta_data).unwrap();
while let Some(rec) = parser.next() {
sig.add_protein(&rec.unwrap().seq()).unwrap();
}
});
});

group.bench_function("force with valid kmers", |b| {
b.iter(|| {
let fasta_data = Cursor::new(data_upper.clone());
let mut sig = template_sig.clone();
let mut parser = parse_fastx_reader(fasta_data).unwrap();
while let Some(rec) = parser.next() {
sig.add_protein(&rec.unwrap().seq()).unwrap();
}
});
});
}

criterion_group!(compute, add_sequence, add_sequence_protein);
criterion_main!(compute);
44 changes: 29 additions & 15 deletions src/core/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ pub struct SeqToHashes {

prot_configured: bool,
aa_seq: Vec<u8>,
translate_iter_step: usize,
}

impl SeqToHashes {
Expand Down Expand Up @@ -222,17 +223,26 @@ impl SeqToHashes {
dna_last_position_check: 0,
prot_configured: false,
aa_seq: Vec::new(),
translate_iter_step: 0,
}
}
}

/*
Iterator that return a kmer hash for all modes except translate.
In translate mode:
- all the frames are processed at once and converted to hashes.
- all the hashes are stored in `hashes_buffer`
- after processing all the kmers, `translate_iter_step` is incremented
per iteration to iterate over all the indeces of the `hashes_buffer`.
- the iterator will die once `translate_iter_step` == length(hashes_buffer)
More info https://github.com/sourmash-bio/sourmash/pull/1946
*/

impl Iterator for SeqToHashes {
type Item = Result<u64, Error>;

fn next(&mut self) -> Option<Self::Item> {
// TODO: Remove the hashes buffer
// Priority for flushing the hashes buffer

if (self.kmer_index < self.max_index) || !self.hashes_buffer.is_empty() {
// Processing DNA or Translated DNA
if !self.is_protein {
Expand Down Expand Up @@ -290,18 +300,17 @@ impl Iterator for SeqToHashes {
let hash = crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed);
self.kmer_index += 1;
Some(Ok(hash))
} else if self.hashes_buffer.is_empty() {
} else if self.hashes_buffer.is_empty() && self.translate_iter_step == 0 {
// Processing protein by translating DNA
// TODO: make it a real iterator not a buffer
// TODO: Implement iterator over frames instead of hashes_buffer.

// Three frames
for i in 0..3 {
for frame_number in 0..3 {
let substr: Vec<u8> = self
.sequence
.iter()
.cloned()
.skip(i)
.take(self.sequence.len() - i)
.skip(frame_number)
.take(self.sequence.len() - frame_number)
.collect();

let aa = to_aa(
Expand All @@ -320,8 +329,8 @@ impl Iterator for SeqToHashes {
.dna_rc
.iter()
.cloned()
.skip(i)
.take(self.dna_rc.len() - i)
.skip(frame_number)
.take(self.dna_rc.len() - frame_number)
.collect();
let aa_rc = to_aa(
&rc_substr,
Expand All @@ -335,11 +344,16 @@ impl Iterator for SeqToHashes {
self.hashes_buffer.push(hash);
});
}
self.kmer_index = self.max_index;
Some(Ok(self.hashes_buffer.remove(0)))
Some(Ok(0))
} else {
let first_element: u64 = self.hashes_buffer.remove(0);
Some(Ok(first_element))
if self.translate_iter_step == self.hashes_buffer.len() {
self.hashes_buffer.clear();
self.kmer_index = self.max_index;
return Some(Ok(0));
}
let curr_idx = self.translate_iter_step;
self.translate_iter_step += 1;
Some(Ok(self.hashes_buffer[curr_idx]))
}
} else {
// Processing protein
Expand Down
38 changes: 30 additions & 8 deletions src/sourmash/command_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,15 @@ def __init__(self, args):

def __call__(self):
args = self.args
params = ComputeParameters(args.ksizes, args.seed, args.protein,
args.dayhoff, args.hp, args.dna,
args.num_hashes,
args.track_abundance, args.scaled)
params = ComputeParameters(ksizes=args.ksizes,
seed=args.seed,
protein=args.protein,
dayhoff=args.dayhoff,
hp=args.hp,
dna=args.dna,
num_hashes=args.num_hashes,
track_abundance=args.track_abundance,
scaled=args.scaled)
sig = SourmashSignature.from_params(params)
return [sig]

Expand Down Expand Up @@ -326,7 +331,17 @@ def save_sigs_to_location(siglist, save_sig):
class ComputeParameters(RustObject):
__dealloc_func__ = lib.computeparams_free

def __init__(self, ksizes, seed, protein, dayhoff, hp, dna, num_hashes, track_abundance, scaled):
def __init__(self,
*,
ksizes=(21, 31, 51),
seed=42,
protein=False,
dayhoff=False,
hp=False,
dna=True,
num_hashes=500,
track_abundance=False,
scaled=0):
self._objptr = lib.computeparams_new()

self.seed = seed
Expand Down Expand Up @@ -359,8 +374,15 @@ def from_manifest_row(cls, row):
else:
ksize = row['ksize'] * 3

p = cls([ksize], DEFAULT_MMHASH_SEED, is_protein, is_dayhoff, is_hp, is_dna,
row['num'], row['with_abundance'], row['scaled'])
p = cls(ksizes=[ksize],
seed=DEFAULT_MMHASH_SEED,
protein=is_protein,
dayhoff=is_dayhoff,
hp=is_hp,
dna=is_dna,
num_hashes=row['num'],
track_abundance=row['with_abundance'],
scaled=row['scaled'])

return p

Expand Down Expand Up @@ -405,7 +427,7 @@ def to_param_str(self):
return ",".join(pi)

def __repr__(self):
return f"ComputeParameters({self.ksizes}, {self.seed}, {self.protein}, {self.dayhoff}, {self.hp}, {self.dna}, {self.num_hashes}, {self.track_abundance}, {self.scaled})"
return f"ComputeParameters(ksizes={self.ksizes}, seed={self.seed}, protein={self.protein}, dayhoff={self.dayhoff}, hp={self.hp}, dna={self.dna}, num_hashes={self.num_hashes}, track_abundance={self.track_abundance}, scaled={self.scaled})"

def __eq__(self, other):
return (self.ksizes == other.ksizes and
Expand Down
21 changes: 11 additions & 10 deletions src/sourmash/command_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,17 @@ def get_compute_params(self, *, split_ksizes=False):
if self.mult_ksize_by_3 and not def_dna:
ksizes = [ k*3 for k in ksizes ]

make_param = lambda ksizes: ComputeParameters(ksizes,
params_d.get('seed', def_seed),
def_protein,
def_dayhoff,
def_hp,
def_dna,
params_d.get('num', def_num),
params_d.get('track_abundance',
def_abund),
params_d.get('scaled', def_scaled))
make_param = lambda ksizes: ComputeParameters(
ksizes=ksizes,
seed=params_d.get('seed', def_seed),
protein=def_protein,
dayhoff=def_dayhoff,
hp=def_hp,
dna=def_dna,
num_hashes=params_d.get('num', def_num),
track_abundance=params_d.get('track_abundance',
def_abund),
scaled=params_d.get('scaled', def_scaled))

if split_ksizes:
for ksize in ksizes:
Expand Down
5 changes: 4 additions & 1 deletion src/sourmash/lca/lca_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,13 @@ def find(self, search_fn, query, **kwargs):
score = search_fn.score_fn(query_size, shared_size, subj_size,
total_size)

# note to self: even with JaccardSearchBestOnly, this will
# CTB note to self: even with JaccardSearchBestOnly, this will
# still iterate over & score all signatures. We should come
# up with a protocol by which the JaccardSearch object can
# signal that it is done, or something.
# For example, see test_lca_jaccard_ordering, where
# for containment we could be done early, but for Jaccard we
# cannot.
if search_fn.passes(score):
if search_fn.collect(score, subj):
if passes_all_picklists(subj, self.picklists):
Expand Down
44 changes: 44 additions & 0 deletions tests/test_lca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2670,3 +2670,47 @@ def test_lca_index_with_picklist_exclude(runtmp):
assert len(siglist) == 9
for ss in siglist:
assert 'Thermotoga' not in ss.name


def test_lca_jaccard_ordering():
# this tests a tricky situation where for three sketches A, B, C,
# |A intersect B| is greater than |A intersect C|
# _but_
# |A jaccard B| is less than |A intersect B|
a = sourmash.MinHash(ksize=31, n=0, scaled=2)
b = a.copy_and_clear()
c = a.copy_and_clear()

a.add_many([1, 2, 3, 4])
b.add_many([1, 2, 3] + list(range(10, 30)))
c.add_many([1, 5])

def _intersect(x, y):
return x.intersection_and_union_size(y)[0]

print('a intersect b:', _intersect(a, b))
print('a intersect c:', _intersect(a, c))
print('a jaccard b:', a.jaccard(b))
print('a jaccard c:', a.jaccard(c))
assert _intersect(a, b) > _intersect(a, c)
assert a.jaccard(b) < a.jaccard(c)

# thresholds to use:
assert a.jaccard(b) < 0.15
assert a.jaccard(c) > 0.15

# now - make signatures, try out :)
ss_a = sourmash.SourmashSignature(a, name='A')
ss_b = sourmash.SourmashSignature(b, name='B')
ss_c = sourmash.SourmashSignature(c, name='C')

db = sourmash.lca.LCA_Database(ksize=31, scaled=2)
db.insert(ss_a)
db.insert(ss_b)
db.insert(ss_c)

sr = db.search(ss_a, threshold=0.15)
print(sr)
assert len(sr) == 2
assert sr[0].signature == ss_a
assert sr[1].signature == ss_c
Loading

0 comments on commit 1797943

Please sign in to comment.