Merge pull request #343 from creare-com/feature/async-algorithm-eval
ENH: Adding multi-threading to algorithm eval node.
mpu-creare authored Nov 25, 2019
2 parents 1c57404 + 050d316 commit 85716a4
Showing 6 changed files with 242 additions and 7 deletions.
38 changes: 36 additions & 2 deletions podpac/core/algorithm/algorithm.py
@@ -18,6 +18,8 @@
from podpac.core.node import COMMON_NODE_DOC
from podpac.core.node import node_eval
from podpac.core.utils import common_doc
from podpac.core.settings import settings
from podpac.core.managers.multi_threading import thread_manager

COMMON_DOC = COMMON_NODE_DOC.copy()

@@ -67,8 +69,40 @@ def eval(self, coordinates, output=None):
        self._requested_coordinates = coordinates

        inputs = {}
        for key, node in self._inputs.items():
            inputs[key] = node.eval(coordinates)

        if settings["MULTITHREADING"]:
            n_threads = thread_manager.request_n_threads(len(self._inputs))
            if n_threads == 1:
                thread_manager.release_n_threads(n_threads)
        else:
            n_threads = 0

        if settings["MULTITHREADING"] and n_threads > 1:
            # Create a function for each thread to execute asynchronously
            def f(node):
                return node.eval(coordinates)

            # Create a pool of size n_threads; note that this may be created from a sub-thread (i.e. not the main thread)
            pool = thread_manager.get_thread_pool(processes=n_threads)

            # Evaluate nodes in parallel/asynchronously
            results = [pool.apply_async(f, [node]) for node in self._inputs.values()]

            # Collect the results in a dictionary
            for key, res in zip(self._inputs.keys(), results):
                inputs[key] = res.get()

            # This prevents any more tasks from being submitted to the pool, and will close the workers once done
            pool.close()

            # Release this number of threads back to the thread manager
            thread_manager.release_n_threads(n_threads)
            self._multi_threaded = True
        else:
            # Evaluate nodes in serial
            for key, node in self._inputs.items():
                inputs[key] = node.eval(coordinates)
            self._multi_threaded = False

        # accumulate output coordinates
        coords_list = [Coordinates.from_xarray(a.coords, crs=a.attrs.get("crs")) for a in inputs.values()]
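
The pattern above (request threads from the shared manager, fan the inputs out with apply_async, then release the threads) is independent of podpac. A minimal standalone sketch of the same technique; the name limited_parallel_eval and the tasks argument are hypothetical, not part of this commit:

from multiprocessing.pool import ThreadPool

def limited_parallel_eval(tasks, max_threads):
    # Run zero-argument callables in parallel when more than one thread
    # is available; otherwise fall back to a plain serial loop.
    n_threads = min(len(tasks), max_threads)
    if n_threads <= 1:
        return [task() for task in tasks]
    pool = ThreadPool(processes=n_threads)
    results = [pool.apply_async(task) for task in tasks]
    out = [res.get() for res in results]
    pool.close()
    return out

print(limited_parallel_eval([lambda: 1, lambda: 2], max_threads=4))  # [1, 2]
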
2 changes: 2 additions & 0 deletions podpac/core/algorithm/generic.py
@@ -97,6 +97,8 @@ def algorithm(self, inputs):
        f_locals = dict(zip(fields, res))

        try:
            import numexpr.evaluate  # Needed for some systems to get around lazy_module issues

            result = ne.evaluate(eqn, f_locals)
        except (NotImplementedError, ImportError):
            result = eval(eqn, f_locals)
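
For context, the try/except above prefers numexpr and falls back to Python's eval when numexpr is missing or cannot handle the expression. A minimal sketch of that fallback pattern, separate from podpac (evaluate_eqn is a hypothetical name):

import numpy as np

def evaluate_eqn(eqn, variables):
    # Prefer the fast numexpr path; fall back to plain eval when numexpr
    # is unavailable or does not implement the expression.
    try:
        import numexpr as ne
        return ne.evaluate(eqn, local_dict=variables)
    except (ImportError, NotImplementedError):
        return eval(eqn, {}, variables)

print(evaluate_eqn("A**2 + B", {"A": np.arange(3), "B": 1}))  # [1 2 5]
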
90 changes: 90 additions & 0 deletions podpac/core/algorithm/test/test_algorithm.py
@@ -4,6 +4,7 @@

import pytest
from collections import OrderedDict
import numpy as np

import podpac
from podpac.core.algorithm.utility import Arange
@@ -40,3 +41,92 @@ def test_base_definition(self):
assert "B" in d["inputs"]

# TODO value of d['inputs']['A'], etc

def test_multi_threading(self):
coords = podpac.Coordinates([[1, 2, 3]], ["lat"])
node1 = Arithmetic(A=Arange(), B=Arange(), eqn="A+B")
node2 = Arithmetic(A=node1, B=Arange(), eqn="A+B")

with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 8
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

omt = node2.eval(coords)

with podpac.settings:
podpac.settings["MULTITHREADING"] = False
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

ost = node2.eval(coords)

np.testing.assert_array_equal(omt, ost)

def test_multi_threading_cache_race(self):
coords = podpac.Coordinates([np.linspace(0, 1, 1024)], ["lat"])
with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 3
podpac.settings["CACHE_OUTPUT_DEFAULT"] = True
podpac.settings["DEFAULT_CACHE"] = ["ram"]
podpac.settings["RAM_CACHE_ENABLED"] = True
podpac.settings.set_unsafe_eval(True)
A = Arithmetic(A=Arange(), eqn="A**2")
B = Arithmetic(A=Arange(), eqn="A**2")
C = Arithmetic(A=Arange(), eqn="A**2")
D = Arithmetic(A=Arange(), eqn="A**2")
E = Arithmetic(A=Arange(), eqn="A**2")
F = Arithmetic(A=Arange(), eqn="A**2")

node2 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, eqn="A+B+C+D+E+F")

om = node2.eval(coords)

from_cache = [n._from_cache for n in node2.inputs.values()]

assert sum(from_cache) > 0

def test_multi_threading_stress_nthreads(self):
coords = podpac.Coordinates([np.linspace(0, 1, 4)], ["lat"])

A = Arithmetic(A=Arange(), eqn="A**0")
B = Arithmetic(A=Arange(), eqn="A**1")
C = Arithmetic(A=Arange(), eqn="A**2")
D = Arithmetic(A=Arange(), eqn="A**3")
E = Arithmetic(A=Arange(), eqn="A**4")
F = Arithmetic(A=Arange(), eqn="A**5")

node2 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, eqn="A+B+C+D+E+F")
node3 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, G=node2, eqn="A+B+C+D+E+F+G")

with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 8
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

omt = node3.eval(coords)

assert node3._multi_threaded
assert not node2._multi_threaded

with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 9 # 2 threads available after first 7
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

omt = node3.eval(coords)

assert node3._multi_threaded
assert node2._multi_threaded
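
The assertions above follow from the thread accounting in podpac/core/managers/multi_threading.py (shown below). A worked sketch of the arithmetic with N_THREADS = 8, using the API this commit adds:

from podpac.core.settings import settings
from podpac.core.managers.multi_threading import thread_manager

settings["N_THREADS"] = 8
outer = thread_manager.request_n_threads(7)  # node3 claims 7 of 8 for its 7 inputs
inner = thread_manager.request_n_threads(6)  # node2 is granted only the 1 remaining
print(outer, inner)  # 7 1 -> node2 releases its single thread and runs serially
thread_manager.release_n_threads(inner)
thread_manager.release_n_threads(outer)
# With N_THREADS = 9, the second request would be granted 2 threads instead,
# so node2 also runs multi-threaded.
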
22 changes: 20 additions & 2 deletions podpac/core/compositor.py
@@ -5,7 +5,6 @@

from __future__ import division, unicode_literals, print_function, absolute_import

from multiprocessing.pool import ThreadPool
import numpy as np
import traitlets as tl

@@ -20,6 +19,7 @@
from podpac.core.data.datasource import COMMON_DATA_DOC
from podpac.core.data.interpolation import interpolation_trait
from podpac.core.utils import trait_is_defined
from podpac.core.managers.multi_threading import thread_manager

COMMON_COMPOSITOR_DOC = COMMON_DATA_DOC.copy() # superset of COMMON_NODE_DOC

@@ -222,24 +222,42 @@ def iteroutputs(self, coordinates):
                s.set_trait("native_coordinates", nc)

        if settings["MULTITHREADING"]:
            n_threads = thread_manager.request_n_threads(len(src_subset))
            if n_threads == 1:
                thread_manager.release_n_threads(n_threads)
        else:
            n_threads = 0

        if settings["MULTITHREADING"] and n_threads > 1:
            # TODO pool of pre-allocated scratch space
            # TODO: docstring?
            def f(src):
                return src.eval(coordinates)

            pool = ThreadPool(processes=settings.get("N_THREADS", 10))
            # Create a pool of size n_threads; note that this may be created from a sub-thread (i.e. not the main thread)
            pool = thread_manager.get_thread_pool(processes=n_threads)

            # Evaluate nodes in parallel/asynchronously
            results = [pool.apply_async(f, [src]) for src in src_subset]

            # Yield results as they are requested, blocking until each result is ready
            for src, res in zip(src_subset, results):
                yield res.get()
                # src._output = None  # free up memory

            # This prevents any more tasks from being submitted to the pool, and will close the workers once done
            pool.close()

            # Release this number of threads back to the thread manager
            thread_manager.release_n_threads(n_threads)
            self._multi_threaded = True
        else:
            output = None  # scratch space
            for src in src_subset:
                output = src.eval(coordinates, output)
                yield output
                # output[:] = np.nan
            self._multi_threaded = False

    @node_eval
    @common_doc(COMMON_COMPOSITOR_DOC)
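
The generator above yields each source's output in submission order while later sources may still be evaluating. The same ordered fan-out reduced to a standalone sketch (iter_results is a hypothetical name):

from multiprocessing.pool import ThreadPool

def iter_results(funcs, n_threads):
    # Submit all work up front, then yield results in submission order,
    # blocking only on the result currently being yielded.
    pool = ThreadPool(processes=n_threads)
    results = [pool.apply_async(f) for f in funcs]
    for res in results:
        yield res.get()
    pool.close()

for value in iter_results([lambda i=i: i * i for i in range(4)], n_threads=2):
    print(value)  # 0, 1, 4, 9 in order
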
86 changes: 86 additions & 0 deletions podpac/core/managers/multi_threading.py
@@ -0,0 +1,86 @@
"""
Module for dealing with multi-threaded execution.
This is used to ensure that the total number of threads specified in the settings is not exceeded.
"""

from __future__ import division, unicode_literals, print_function, absolute_import

from multiprocessing import Lock
from multiprocessing.pool import ThreadPool

from podpac.core.settings import settings

DEFAULT_N_THREADS = 10


class ThreadManager(object):
""" This is a singleton class that keeps track of the total number of threads used in an application.
"""

_lock = Lock()
cache_lock = Lock()
_n_threads_used = 0
__instance = None

def __new__(cls):
if ThreadManager.__instance is None:
ThreadManager.__instance = object.__new__(cls)
return ThreadManager.__instance

def request_n_threads(self, n):
""" Returns the number of threads allowed for a pool taking into account all other threads application, as
specified by podpac.settings["N_THREADS"].
Parameters
-----------
n : int
Number of threads requested by operation
Returns
--------
int
Number of threads a pool may use. Note, this may be less than or equal to n, and may be 0.
"""
with self._lock:
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
claimed = min(available, n)
self._n_threads_used += claimed
return claimed

def release_n_threads(self, n):
""" This releases the number of threads specified.
Parameters
------------
n : int
Number of threads to be released
Returns
--------
int
Number of threads available after releases 'n' threads
"""
with self._lock:
self._n_threads_used = max(0, self._n_threads_used - n)
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
return available

def get_thread_pool(self, processes):
""" Creates a threadpool that can be used to run jobs in parallel.
Parameters
-----------
processes : int
The number of threads or workers that will be part of the pool
Returns
--------
multiprocessing.ThreadPool
An instance of the ThreadPool class
"""
return ThreadPool(processes=processes)


thread_manager = ThreadManager()
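
In use, a caller claims threads before fanning out and releases them when finished; the manager only does bounded accounting against settings["N_THREADS"]. A usage sketch of the API defined above:

from podpac.core.managers.multi_threading import thread_manager

n = thread_manager.request_n_threads(4)  # claim up to 4 threads from the global budget
try:
    if n > 1:
        pool = thread_manager.get_thread_pool(processes=n)
        # ... submit work with pool.apply_async(...) and collect with .get() ...
        pool.close()
    # with n <= 1, callers fall back to a serial loop (see algorithm.py above)
finally:
    thread_manager.release_n_threads(n)  # always return the claimed threads
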
11 changes: 8 additions & 3 deletions podpac/core/node.py
@@ -25,6 +25,7 @@
from podpac.core.coordinates import Coordinates
from podpac.core.style import Style
from podpac.core.cache import CacheCtrl, get_default_cache_ctrl, S3CacheStore, make_cache_ctrl
from podpac.core.managers.multi_threading import thread_manager


COMMON_NODE_DOC = {
@@ -136,6 +137,8 @@ def _validate_units(self, d):
    _requested_coordinates = tl.Instance(Coordinates, allow_none=True)
    _output = tl.Instance(UnitsDataArray, allow_none=True)
    _from_cache = tl.Bool(allow_none=True, default_value=None)
    # Flag that is True if the Node was run multi-threaded, or None if the question doesn't apply
    _multi_threaded = tl.Bool(allow_none=True, default_value=None)

    def __init__(self, **kwargs):
        """ Do not overwrite me """
@@ -478,11 +481,11 @@ def put_cache(self, data, key, coordinates=None, overwrite=False):
        NodeException
            Cached data already exists (and overwrite is False)
        """

        if not overwrite and self.has_cache(key, coordinates=coordinates):
            raise NodeException("Cached data already exists for key '%s' and coordinates %s" % (key, coordinates))

        self.cache_ctrl.put(self, data, key, coordinates=coordinates, update=overwrite)
        with thread_manager.cache_lock:
            self.cache_ctrl.put(self, data, key, coordinates=coordinates, update=overwrite)

def has_cache(self, key, coordinates=None):
"""
@@ -500,7 +503,8 @@ def has_cache(self, key, coordinates=None):
        bool
            True if there is cached data for this node, key, and coordinates.
        """
        return self.cache_ctrl.has(self, key, coordinates=coordinates)
        with thread_manager.cache_lock:
            return self.cache_ctrl.has(self, key, coordinates=coordinates)

    def rem_cache(self, key, coordinates=None, mode=None):
        """
@@ -807,6 +811,7 @@ def wrapper(self, coordinates, output=None):
        self._requested_coordinates = coordinates
        key = cache_key
        cache_coordinates = coordinates.transpose(*sorted(coordinates.dims))  # order agnostic caching

        if self.has_cache(key, cache_coordinates) and not self.cache_update:
            data = self.get_cache(key, cache_coordinates)
            if output is not None:
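
The cache_lock used in put_cache and has_cache serializes individual cache reads and writes across worker threads. A generic sketch of the hazard such a lock guards against, independent of podpac's cache backends (names hypothetical):

import threading

_cache = {}
_cache_lock = threading.Lock()

def cache_put(key, value, overwrite=False):
    # Make the check-then-write sequence atomic; without the lock, two
    # threads could both see the key missing and both write.
    with _cache_lock:
        if not overwrite and key in _cache:
            raise ValueError("cached data already exists for %r" % key)
        _cache[key] = value

def cache_has(key):
    with _cache_lock:
        return key in _cache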
