Skip to content

Commit

Permalink
FIX: Address review comments.
Browse files Browse the repository at this point in the history
* Made the ThreadPool creation part of the thread_manager
* Doing serial computation if N_THREADS == 1 (had to release the obtained thread)
* Added the _multi_threaded node attribute to help with testing/debugging multi-threaded execution
* Using the Lock context manager instead of 'acquire/release'
* Added a test that stresses the number of threads in the execution and checks that the correct number of threads causes a cascade to lower levels in the pipeline.
  • Loading branch information
mpu-creare committed Nov 22, 2019
1 parent 33ae85a commit 050d316
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 25 deletions.
9 changes: 6 additions & 3 deletions podpac/core/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import division, unicode_literals, print_function, absolute_import

from multiprocessing.pool import ThreadPool
from collections import OrderedDict
import inspect

Expand Down Expand Up @@ -73,16 +72,18 @@ def eval(self, coordinates, output=None):

if settings["MULTITHREADING"]:
n_threads = thread_manager.request_n_threads(len(self._inputs))
if n_threads == 1:
thread_manager.release_n_threads(n_threads)
else:
n_threads = 0

if settings["MULTITHREADING"] and n_threads > 0:
if settings["MULTITHREADING"] and n_threads > 1:
# Create a function for each thread to execute asynchronously
def f(node):
return node.eval(coordinates)

# Create pool of size n_threads, note, this may be created from a sub-thread (i.e. not the main thread)
pool = ThreadPool(processes=n_threads)
pool = thread_manager.get_thread_pool(processes=n_threads)

# Evaluate nodes in parallel/asynchronously
results = [pool.apply_async(f, [node]) for node in self._inputs.values()]
Expand All @@ -96,10 +97,12 @@ def f(node):

# Release these number of threads back to the thread pool
thread_manager.release_n_threads(n_threads)
self._multi_threaded = True
else:
# Evaluate nodes in serial
for key, node in self._inputs.items():
inputs[key] = node.eval(coordinates)
self._multi_threaded = False

# accumulate output coordinates
coords_list = [Coordinates.from_xarray(a.coords, crs=a.attrs.get("crs")) for a in inputs.values()]
Expand Down
39 changes: 39 additions & 0 deletions podpac/core/algorithm/test/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,42 @@ def test_multi_threading_cache_race(self):
from_cache = [n._from_cache for n in node2.inputs.values()]

assert sum(from_cache) > 0

def test_multi_threading_stress_nthreads(self):
    """Stress N_THREADS accounting across a two-level pipeline.

    node3 has 7 inputs (A-F plus node2); node2 has 6 of the same inputs.
    The test checks that whether node2 runs multi-threaded depends on how
    many threads remain after node3 has claimed threads for its own inputs.
    """
    coords = podpac.Coordinates([np.linspace(0, 1, 4)], ["lat"])

    # Six independent leaf nodes so the parent nodes have many inputs.
    A = Arithmetic(A=Arange(), eqn="A**0")
    B = Arithmetic(A=Arange(), eqn="A**1")
    C = Arithmetic(A=Arange(), eqn="A**2")
    D = Arithmetic(A=Arange(), eqn="A**3")
    E = Arithmetic(A=Arange(), eqn="A**4")
    F = Arithmetic(A=Arange(), eqn="A**5")

    # node2: 6 inputs; node3: 7 inputs (the same 6 plus node2 itself).
    node2 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, eqn="A+B+C+D+E+F")
    node3 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, G=node2, eqn="A+B+C+D+E+F+G")

    with podpac.settings:
        podpac.settings["MULTITHREADING"] = True
        podpac.settings["N_THREADS"] = 8
        # Disable all caching so every node actually evaluates (no cache
        # short-circuit) and the thread accounting is exercised.
        podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
        podpac.settings["DEFAULT_CACHE"] = []
        podpac.settings["RAM_CACHE_ENABLED"] = False
        podpac.settings.set_unsafe_eval(True)

        omt = node3.eval(coords)

    # With 8 threads, node3 claims 7 for its inputs, leaving only 1 —
    # a single thread is released and node2 falls back to serial eval.
    assert node3._multi_threaded
    assert not node2._multi_threaded

    with podpac.settings:
        podpac.settings["MULTITHREADING"] = True
        podpac.settings["N_THREADS"] = 9  # 2 threads available after first 7
        podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
        podpac.settings["DEFAULT_CACHE"] = []
        podpac.settings["RAM_CACHE_ENABLED"] = False
        podpac.settings.set_unsafe_eval(True)

        omt = node3.eval(coords)

    # With 9 threads, 2 remain for node2 (> 1), so node2 also runs
    # multi-threaded — the cascade reaches the lower pipeline level.
    assert node3._multi_threaded
    assert node2._multi_threaded
10 changes: 6 additions & 4 deletions podpac/core/compositor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import division, unicode_literals, print_function, absolute_import

from multiprocessing.pool import ThreadPool
import numpy as np
import traitlets as tl

Expand Down Expand Up @@ -224,17 +223,19 @@ def iteroutputs(self, coordinates):

if settings["MULTITHREADING"]:
n_threads = thread_manager.request_n_threads(len(src_subset))
if n_threads == 1:
thread_manager.release_n_threads(n_threads)
else:
n_threads = 0

if settings["MULTITHREADING"] and n_threads > 0:
if settings["MULTITHREADING"] and n_threads > 1:
# TODO pool of pre-allocated scratch space
# TODO: docstring?
def f(src):
return src.eval(coordinates)

# Create pool of size n_threads, note, this may be created from a sub-thread (i.e. not the main thread)
pool = ThreadPool(processes=n_threads)
pool = thread_manager.get_thread_pool(processes=n_threads)

# Evaluate nodes in parallel/asynchronously
results = [pool.apply_async(f, [src]) for src in src_subset]
Expand All @@ -249,13 +250,14 @@ def f(src):

# Release these number of threads back to the thread pool
thread_manager.release_n_threads(n_threads)

self._multi_threaded = True
else:
output = None # scratch space
for src in src_subset:
output = src.eval(coordinates, output)
yield output
# output[:] = np.nan
self._multi_threaded = False

@node_eval
@common_doc(COMMON_COMPOSITOR_DOC)
Expand Down
36 changes: 25 additions & 11 deletions podpac/core/managers/multi_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import division, unicode_literals, print_function, absolute_import

from multiprocessing import Lock
from multiprocessing.pool import ThreadPool

from podpac.core.settings import settings

Expand Down Expand Up @@ -42,12 +43,11 @@ def request_n_threads(self, n):
int
Number of threads a pool may use. Note, this may be less than or equal to n, and may be 0.
"""
self._lock.acquire()
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
claimed = min(available, n)
self._n_threads_used += claimed
self._lock.release()
return claimed
with self._lock:
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
claimed = min(available, n)
self._n_threads_used += claimed
return claimed

def release_n_threads(self, n):
""" This releases the number of threads specified.
Expand All @@ -62,11 +62,25 @@ def release_n_threads(self, n):
int
Number of threads available after releases 'n' threads
"""
self._lock.acquire()
self._n_threads_used = max(0, self._n_threads_used - n)
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
self._lock.release()
return available
with self._lock:
self._n_threads_used = max(0, self._n_threads_used - n)
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
return available

def get_thread_pool(self, processes):
    """ Create a thread pool for evaluating jobs in parallel.

    Parameters
    -----------
    processes : int
        Number of worker threads the pool will contain.

    Returns
    --------
    multiprocessing.pool.ThreadPool
        A newly created pool with `processes` workers. Callers are
        responsible for closing the pool when finished with it.
    """
    pool = ThreadPool(processes=processes)
    return pool


thread_manager = ThreadManager()
13 changes: 6 additions & 7 deletions podpac/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def _validate_units(self, d):
_requested_coordinates = tl.Instance(Coordinates, allow_none=True)
_output = tl.Instance(UnitsDataArray, allow_none=True)
_from_cache = tl.Bool(allow_none=True, default_value=None)
# Flag that is True if the Node was run multi-threaded, or None if the question doesn't apply
_multi_threaded = tl.Bool(allow_none=True, default_value=None)

def __init__(self, **kwargs):
""" Do not overwrite me """
Expand Down Expand Up @@ -482,9 +484,8 @@ def put_cache(self, data, key, coordinates=None, overwrite=False):
if not overwrite and self.has_cache(key, coordinates=coordinates):
raise NodeException("Cached data already exists for key '%s' and coordinates %s" % (key, coordinates))

thread_manager.cache_lock.acquire()
self.cache_ctrl.put(self, data, key, coordinates=coordinates, update=overwrite)
thread_manager.cache_lock.release()
with thread_manager.cache_lock:
self.cache_ctrl.put(self, data, key, coordinates=coordinates, update=overwrite)

def has_cache(self, key, coordinates=None):
"""
Expand All @@ -502,10 +503,8 @@ def has_cache(self, key, coordinates=None):
bool
True if there is cached data for this node, key, and coordinates.
"""
thread_manager.cache_lock.acquire()
has_cache = self.cache_ctrl.has(self, key, coordinates=coordinates)
thread_manager.cache_lock.release()
return has_cache
with thread_manager.cache_lock:
return self.cache_ctrl.has(self, key, coordinates=coordinates)

def rem_cache(self, key, coordinates=None, mode=None):
"""
Expand Down

0 comments on commit 050d316

Please sign in to comment.