diff --git a/tomotwin/modules/common/findmax/findmax.py b/tomotwin/modules/common/findmax/findmax.py index 881ad60..9587c2c 100644 --- a/tomotwin/modules/common/findmax/findmax.py +++ b/tomotwin/modules/common/findmax/findmax.py @@ -103,7 +103,6 @@ def get_avg_pos(classes: List[int], regions: np.array, region_max_value: List, i return maxima_coords def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **kwargs) -> tuple[list, np.array]: - """ :param volume: 3D volume :param tolerance: Tolerance for detection @@ -164,7 +163,7 @@ def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **k if global_min == None: global_min = np.min(image) + tolerance - print("effective global min:", global_min) + # print("effective global min:", global_min) @@ -190,14 +189,8 @@ def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **k k = 0 region_max_value = [] working_image_raveled = working_image.ravel(order) - import tqdm - desc="Locate" - pos=None - if 'tqdm_pos' in kwargs: - desc = f"Locate class {kwargs['tqdm_pos']}" - pos = kwargs["tqdm_pos"] - - for seed_point in tqdm.tqdm(coords_sorted,position=pos, desc=desc): + + for seed_point in coords_sorted: try: iter(seed_point) except TypeError: @@ -242,7 +235,6 @@ def find_maxima(volume: np.array, tolerance: float, global_min: float = 0.5, **k chunked_arrays = np.array_split(region_list, num_cores) from concurrent.futures import ProcessPoolExecutor as Pool with Pool(multiprocessing.cpu_count()//2) as pool: - print("Call get_avg_pos") maxima_coords = pool.map(partial(get_avg_pos, regions=regions, region_max_value=region_max_value, image=image), chunked_arrays) #maxima_coords = pool.map(get_avg_pos, repeat(regions), repeat(region_max_value), repeat(image), chunked_arrays) diff --git a/tomotwin/modules/inference/findmaxima_locator.py b/tomotwin/modules/inference/findmaxima_locator.py index 2a3fd37..6ee3421 100644 --- a/tomotwin/modules/inference/findmaxima_locator.py 
+++ b/tomotwin/modules/inference/findmaxima_locator.py @@ -11,12 +11,16 @@ may affect the distribution and modification of this software. """ +import itertools import multiprocessing from concurrent.futures import ProcessPoolExecutor as Pool from typing import List, Tuple +import dask +import dask.array as da import numpy as np import pandas as pd +from tqdm import tqdm from tomotwin.modules.common.findmax.findmax import find_maxima from tomotwin.modules.inference.locator import Locator @@ -107,6 +111,59 @@ def maxima_to_df( dat["metric_best"] = dat["metric_best"].astype(np.float16) return dat + @staticmethod + def apply_findmax_dask(vol: np.array, + tolerance: float, + global_min: float, + **kwargs + ) -> List[Tuple]: + ''' + Applies the findmax procedure to the 3d volume, chunk by chunk via dask + :param vol: Volume where maxima need to be detected. + :param tolerance: Prominence of the peak + :param global_min: global minimum + :param kwargs: optional keyword arguments (e.g. tqdm_pos for progress bar placement) + :return: List of maxima; each entry has 3 elements: the maxima position, the size (region growing) and the maxima value + ''' + + da_vol = da.from_array(vol, chunks=200) # really constant 200? 
+ lazy_results = [] + offsets = [] + indices = list(itertools.product(*map(range, da_vol.blocks.shape))) + with tqdm(total=len(indices), position=kwargs.get("tqdm_pos"), + desc=f"Locate class {kwargs['tqdm_pos']}") as pbar: + + def find_max_bar_wrapper(*args, **kwargs): + r = find_maxima(*args, **kwargs) + pbar.update(1) + return r + + for inds in indices: + chunk = da_vol.blocks[inds] + offsets.append([a * b for a, b in zip(da_vol.chunksize, inds)]) + lr = dask.delayed(find_max_bar_wrapper)(np.asarray(chunk), tolerance=tolerance, global_min=global_min, + tqdm_pos=kwargs.get("tqdm_pos"), pbar=pbar) + lazy_results.append(lr) + + # futures = dask.persist(*lazy_results) + a = dask.compute(*lazy_results) + # maximas, _ + + # apply offsets + maximas = [] + # k = 0 + for k, (maximas_in_chunk, _) in enumerate(a): + off_chunk = offsets[k] + for s_i, single_maxima in enumerate(maximas_in_chunk): + new_pos = tuple([a + b for a, b in zip(single_maxima[0], off_chunk)]) + new_entry = [new_pos] + new_entry.extend(single_maxima[1:]) + maximas_in_chunk[s_i] = new_entry + maximas.extend(maximas_in_chunk) + # k = k + 1 + + return maximas + @staticmethod + def apply_findmax(vol: np.array, + tolerance: float, @@ -140,13 +197,18 @@ def locate_class(class_id, global_min: float, ) -> Tuple[pd.DataFrame, np.array]: vol = FindMaximaLocator.to_volume(map_output, target_class=class_id, window_size=window_size, stride=stride) - maximas = FindMaximaLocator.apply_findmax(vol=vol, - class_id=class_id, - window_size=window_size, - stride=stride, - tolerance=tolerance, - global_min=global_min, - tqdm_pos=class_id) + maximas = FindMaximaLocator.apply_findmax_dask(vol=vol, + class_id=class_id, + window_size=window_size, + stride=stride, + tolerance=tolerance, + global_min=global_min, + tqdm_pos=class_id) + + maximas = [ + m for m in maximas if m[1] > 1 + ] # more than one pixel coordinate must be involved. 
+ print("done", class_id) particle_df = FindMaximaLocator.maxima_to_df( maximas, class_id, stride=stride, boxsize=window_size