Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compare: move get_similarities_at_index into compare_parallel function to help with memory #1509

Closed
wants to merge 9 commits into from
74 changes: 32 additions & 42 deletions src/sourmash/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,38 +92,6 @@ def similarity_args_unpack(args, ignore_abundance, downsample):
downsample=downsample)


def get_similarities_at_index(index, ignore_abundance, downsample, siglist):
    """Compare the signature at ``index`` with every signature after it.

    Only the pairs ``(siglist[index], siglist[j])`` with ``j > index`` are
    evaluated, so pairs involving earlier positions are not recomputed here.

    :param int index: position in ``siglist`` of the signature that every
        later signature is compared against
    :param boolean ignore_abundance
        If the sketches are not abundance weighted, or ignore_abundance=True,
        compute Jaccard similarity.

        If the sketches are abundance weighted, calculate the angular
        similarity.
    :param boolean downsample by scaled if True
    :param siglist list of signatures
    :return: list of similarities, one per signature from index + 1 onwards,
        in ``siglist`` order
    """
    began = time.time()
    anchor = siglist[index]
    later_sigs = siglist[index + 1:]
    # Each pair is handed over as a single tuple, which is the calling
    # convention of similarity_args_unpack(args, ignore_abundance, downsample).
    row = [
        similarity_args_unpack(
            (anchor, other),
            ignore_abundance=ignore_abundance,
            downsample=downsample)
        for other in later_sigs
    ]
    # end='\r' is forwarded to notify unchanged.
    notify(
        "comparison for index {} done in {:.5f} seconds",
        index,
        time.time() - began,
        end='\r')
    return row


def compare_parallel(siglist, ignore_abundance, downsample, n_jobs):
"""Compare all combinations of signatures and return a matrix
of similarities. Processes combinations in parallel on number of processes
Expand Down Expand Up @@ -159,15 +127,37 @@ def compare_parallel(siglist, ignore_abundance, downsample, n_jobs):
memmap_similarities, filename = to_memmap(similarities)
notify("Initialized memmapped similarities matrix")

# Initialize the function using functools.partial with the common arguments like
# siglist, ignore_abundance, downsample, for computing all the signatures
# The only changing parameter that will be mapped from the pool is the index
func = partial(
get_similarities_at_index,
siglist=siglist,
ignore_abundance=ignore_abundance,
downsample=downsample)
notify("Created similarity func")
# Declare the worker function inside this function to avoid sharing siglist via pickle
global get_similarities_at_index
def get_similarities_at_index(index):
    """Compare the signature at ``index`` with every signature after it.

    Only the pairs ``(siglist[index], siglist[j])`` with ``j > index`` are
    evaluated, so pairs involving earlier positions are not recomputed here.

    ``siglist``, ``ignore_abundance`` and ``downsample`` are not parameters of
    this function; they are looked up from the surrounding scope.

    If the sketches are not abundance weighted, or ignore_abundance=True,
    compute Jaccard similarity.

    If the sketches are abundance weighted, calculate the angular
    similarity.

    :param int index: position in ``siglist`` of the signature that every
        later signature is compared against
    :return: list of similarities, one per signature from index + 1 onwards,
        in ``siglist`` order
    """
    began = time.time()
    anchor = siglist[index]
    later_sigs = siglist[index + 1:]
    # Each pair is handed over as a single tuple, which is the calling
    # convention of similarity_args_unpack(args, ignore_abundance, downsample).
    row = []
    for other in later_sigs:
        row.append(
            similarity_args_unpack(
                (anchor, other),
                ignore_abundance=ignore_abundance,
                downsample=downsample))
    # end='\r' is forwarded to notify unchanged.
    notify(
        "comparison for index {} done in {:.5f} seconds",
        index,
        time.time() - began,
        end='\r')
    return row

# Initialize multiprocess.pool
pool = multiprocessing.Pool(processes=n_jobs)
Expand All @@ -179,7 +169,7 @@ def compare_parallel(siglist, ignore_abundance, downsample, n_jobs):
notify("Calculated chunk size for multiprocessing")

# This will not generate the results yet, since pool.imap returns a generator
result = pool.imap(func, range(length_siglist), chunksize=chunksize)
result = pool.imap(get_similarities_at_index, range(length_siglist), chunksize=chunksize)
notify("Initialized multiprocessing pool.imap")

# Enumerate and calculate similarities at each of the indices
Expand Down