Merge remote-tracking branch 'origin/rapids-23-12' into rapids-23-12
thorstenwagner committed Feb 23, 2024
2 parents 6a7678f + 65fd256 commit 01090e5
Showing 2 changed files with 72 additions and 18 deletions.
14 changes: 3 additions & 11 deletions tomotwin/modules/common/findmax/findmax.py
@@ -103,7 +103,6 @@ def get_avg_pos(classes: List[int], regions: np.array, region_max_value: List, i
return maxima_coords

def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **kwargs) -> tuple[list, np.array]:

"""
:param volume: 3D volume
:param tolerance: Tolerance for detection
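
For orientation, a minimal usage sketch of find_maxima on synthetic data. The tolerance and global_min values are arbitrary, and the two-element return (maxima list plus region array) is assumed from the signature above:

    import numpy as np
    from tomotwin.modules.common.findmax.findmax import find_maxima

    vol = np.random.rand(64, 64, 64).astype(np.float32)   # synthetic 3D volume
    maxima, regions = find_maxima(vol, tolerance=0.2, global_min=0.5)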
@@ -164,7 +163,7 @@ def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **k
if global_min is None:
global_min = np.min(image) + tolerance

print("effective global min:", global_min)
# print("effective global min:", global_min)



@@ -190,14 +189,8 @@ def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **k
k = 0
region_max_value = []
working_image_raveled = working_image.ravel(order)
-import tqdm
-desc = "Locate"
-pos = None
-if 'tqdm_pos' in kwargs:
-    desc = f"Locate class {kwargs['tqdm_pos']}"
-    pos = kwargs["tqdm_pos"]
-
-for seed_point in tqdm.tqdm(coords_sorted, position=pos, desc=desc):
+
+for seed_point in coords_sorted:
try:
iter(seed_point)
except TypeError:
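
The try/iter guard above normalizes seed points that may arrive as bare scalars instead of coordinate tuples. The same pattern in isolation, with a hypothetical helper name:

    def as_coordinate(seed):
        try:
            iter(seed)         # already iterable, e.g. a (z, y, x) tuple
        except TypeError:
            seed = (seed,)     # wrap a bare scalar into a 1-tuple
        return tuple(seed)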
@@ -242,7 +235,6 @@ def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **k
chunked_arrays = np.array_split(region_list, num_cores)
from concurrent.futures import ProcessPoolExecutor as Pool
with Pool(multiprocessing.cpu_count()//2) as pool:
print("Call get_avg_pos")
maxima_coords = pool.map(partial(get_avg_pos, regions=regions, region_max_value=region_max_value, image=image),
chunked_arrays)
#maxima_coords = pool.map(get_avg_pos, repeat(regions), repeat(region_max_value), repeat(image), chunked_arrays)
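
The chunk-and-map pattern used with the process pool here, reduced to a self-contained sketch. The worker function and its data are made up; only the array_split/partial/map structure mirrors the code above:

    import multiprocessing
    import numpy as np
    from concurrent.futures import ProcessPoolExecutor
    from functools import partial

    def scale_chunk(chunk, factor):
        return chunk * factor

    if __name__ == "__main__":
        num_workers = max(1, multiprocessing.cpu_count() // 2)
        chunks = np.array_split(np.arange(100), num_workers)
        with ProcessPoolExecutor(num_workers) as pool:
            # partial pins the constant argument; map feeds one chunk per task
            results = list(pool.map(partial(scale_chunk, factor=2.0), chunks))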
76 changes: 69 additions & 7 deletions tomotwin/modules/inference/findmaxima_locator.py
@@ -11,12 +11,16 @@
may affect the distribution and modification of this software.
"""

import itertools
import multiprocessing
from concurrent.futures import ProcessPoolExecutor as Pool
from typing import List, Tuple

import dask
import dask.array as da
import numpy as np
import pandas as pd
from tqdm import tqdm

from tomotwin.modules.common.findmax.findmax import find_maxima
from tomotwin.modules.inference.locator import Locator
@@ -107,6 +111,59 @@ def maxima_to_df(
dat["metric_best"] = dat["metric_best"].astype(np.float16)
return dat

@staticmethod
def apply_findmax_dask(vol: np.array,
tolerance: float,
global_min: float,
**kwargs
) -> List[Tuple]:
'''
Applies the findmax procedure to the 3D volume.
:param vol: Volume in which maxima need to be detected.
:param tolerance: Prominence of a peak.
:param global_min: Global minimum.
:param kwargs: Additional keyword arguments (e.g. tqdm_pos for the progress bar).
:return: List of maxima. Each entry holds the maximum's position, its size (from region growing), and its value.
'''

da_vol = da.from_array(vol, chunks=200)  # TODO: is the fixed chunk size of 200 appropriate for all volumes?
lazy_results = []
offsets = []
indices = list(itertools.product(*map(range, da_vol.blocks.shape)))
with tqdm(total=len(indices), position=kwargs.get("tqdm_pos"),
          desc=f"Locate class {kwargs.get('tqdm_pos', '?')}") as pbar:

def find_max_bar_wrapper(*args, **kw):
    # run find_maxima on one chunk, then advance the shared progress bar
    r = find_maxima(*args, **kw)
    pbar.update(1)
    return r

(Codecov / codecov/patch warning: added lines 137-139 of tomotwin/modules/inference/findmaxima_locator.py are not covered by tests.)

for inds in indices:
    chunk = da_vol.blocks[inds]
    # voxel offset of this chunk within the full volume
    offsets.append([a * b for a, b in zip(da_vol.chunksize, inds)])
    # the wrapper closes over pbar; find_maxima no longer consumes tqdm_pos
    lr = dask.delayed(find_max_bar_wrapper)(np.asarray(chunk), tolerance=tolerance,
                                            global_min=global_min)
    lazy_results.append(lr)

results = dask.compute(*lazy_results)  # each element is (maxima_in_chunk, regions)

# apply offsets: shift chunk-local positions into global volume coordinates
maximas = []
for k, (maximas_in_chunk, _) in enumerate(results):
    off_chunk = offsets[k]
    for s_i, single_maxima in enumerate(maximas_in_chunk):
        new_pos = tuple(a + b for a, b in zip(single_maxima[0], off_chunk))
        new_entry = [new_pos]
        new_entry.extend(single_maxima[1:])
        maximas_in_chunk[s_i] = new_entry
    maximas.extend(maximas_in_chunk)

return maximas

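
The block/offset bookkeeping above, as a standalone sketch: each dask block becomes a lazy task, and per-block results stay aligned with the voxel offsets needed to map positions back to global coordinates. np.max stands in for find_maxima, and the volume is synthetic:

    import itertools
    import dask
    import dask.array as da
    import numpy as np

    vol = np.random.rand(300, 300, 300).astype(np.float32)
    da_vol = da.from_array(vol, chunks=200)     # blocks of up to 200^3 voxels

    tasks, offsets = [], []
    for inds in itertools.product(*map(range, da_vol.blocks.shape)):
        # voxel offset of this block inside the full volume
        offsets.append(tuple(size * i for size, i in zip(da_vol.chunksize, inds)))
        # dask resolves the block to a NumPy array when the task runs
        tasks.append(dask.delayed(np.max)(da_vol.blocks[inds]))

    block_maxima = dask.compute(*tasks)         # aligned one-to-one with offsets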
@staticmethod
def apply_findmax(vol: np.array,
tolerance: float,
@@ -140,13 +197,18 @@ def locate_class(class_id,
global_min: float,
) -> Tuple[pd.DataFrame, np.array]:
vol = FindMaximaLocator.to_volume(map_output, target_class=class_id, window_size=window_size, stride=stride)
-maximas = FindMaximaLocator.apply_findmax(vol=vol,
-                                          class_id=class_id,
-                                          window_size=window_size,
-                                          stride=stride,
-                                          tolerance=tolerance,
-                                          global_min=global_min,
-                                          tqdm_pos=class_id)
+maximas = FindMaximaLocator.apply_findmax_dask(vol=vol,
+                                               class_id=class_id,
+                                               window_size=window_size,
+                                               stride=stride,
+                                               tolerance=tolerance,
+                                               global_min=global_min,
+                                               tqdm_pos=class_id)
+
+# more than one voxel coordinate must be involved in a detection
+maximas = [m for m in maximas if m[1] > 1]
print("done", class_id)
particle_df = FindMaximaLocator.maxima_to_df(
maximas, class_id, stride=stride, boxsize=window_size
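
Finally, the new size filter in locate_class in toy form, with the (position, size, value) entry layout assumed from apply_findmax_dask above:

    maximas = [((10, 12, 7), 5, 0.91), ((3, 3, 3), 1, 0.42)]
    maximas = [m for m in maximas if m[1] > 1]   # drops the single-voxel hit
    print(maximas)                               # [((10, 12, 7), 5, 0.91)]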
