ENH: Adding multi-threading to algorithm eval node. #343

Merged (8 commits) on Nov 25, 2019
38 changes: 36 additions & 2 deletions podpac/core/algorithm/algorithm.py
@@ -18,6 +18,8 @@
from podpac.core.node import COMMON_NODE_DOC
from podpac.core.node import node_eval
from podpac.core.utils import common_doc
from podpac.core.settings import settings
from podpac.core.managers.multi_threading import thread_manager

COMMON_DOC = COMMON_NODE_DOC.copy()

@@ -67,8 +69,40 @@ def eval(self, coordinates, output=None):
self._requested_coordinates = coordinates

inputs = {}
for key, node in self._inputs.items():
inputs[key] = node.eval(coordinates)

if settings["MULTITHREADING"]:
n_threads = thread_manager.request_n_threads(len(self._inputs))
if n_threads == 1:
thread_manager.release_n_threads(n_threads)
else:
n_threads = 0

if settings["MULTITHREADING"] and n_threads > 1:
# Create a function for each thread to execute asynchronously
def f(node):
return node.eval(coordinates)

# Create a pool of size n_threads. Note that this pool may be created from a sub-thread (i.e. not the main thread)
pool = thread_manager.get_thread_pool(processes=n_threads)

# Evaluate nodes in parallel/asynchronously
results = [pool.apply_async(f, [node]) for node in self._inputs.values()]

# Collect the results in a dictionary
for key, res in zip(self._inputs.keys(), results):
inputs[key] = res.get()

# This prevents any more tasks from being submitted to the pool and closes the workers once they are done
pool.close()

# Release this number of threads back to the thread manager's budget
thread_manager.release_n_threads(n_threads)
self._multi_threaded = True
else:
# Evaluate nodes in serial
for key, node in self._inputs.items():
inputs[key] = node.eval(coordinates)
self._multi_threaded = False

# accumulate output coordinates
coords_list = [Coordinates.from_xarray(a.coords, crs=a.attrs.get("crs")) for a in inputs.values()]
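With this change, an Algorithm node evaluates its inputs in parallel whenever multi-threading is enabled and more than one thread can be claimed from the shared budget; otherwise it falls back to the serial loop. A minimal usage sketch, mirroring the new tests added below (the coordinates and equation are illustrative only):

```python
import podpac
from podpac.core.algorithm.generic import Arithmetic
from podpac.core.algorithm.utility import Arange

coords = podpac.Coordinates([[1, 2, 3]], ["lat"])
node = Arithmetic(A=Arange(), B=Arange(), eqn="A+B")

with podpac.settings:
    podpac.settings["MULTITHREADING"] = True
    podpac.settings["N_THREADS"] = 8
    podpac.settings.set_unsafe_eval(True)  # the new tests enable unsafe eval for Arithmetic

    out = node.eval(coords)      # inputs A and B are evaluated on worker threads
    assert node._multi_threaded  # flag set by the threaded branch above
```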
2 changes: 2 additions & 0 deletions podpac/core/algorithm/generic.py
@@ -97,6 +97,8 @@ def algorithm(self, inputs):
f_locals = dict(zip(fields, res))

try:
import numexpr.evaluate # Needed for some systems to get around lazy_module issues

result = ne.evaluate(eqn, f_locals)
except (NotImplementedError, ImportError):
result = eval(eqn, f_locals)
90 changes: 90 additions & 0 deletions podpac/core/algorithm/test/test_algorithm.py
@@ -4,6 +4,7 @@

import pytest
from collections import OrderedDict
import numpy as np

import podpac
from podpac.core.algorithm.utility import Arange
@@ -40,3 +41,92 @@ def test_base_definition(self):
assert "B" in d["inputs"]

# TODO value of d['inputs']['A'], etc

def test_multi_threading(self):
coords = podpac.Coordinates([[1, 2, 3]], ["lat"])
node1 = Arithmetic(A=Arange(), B=Arange(), eqn="A+B")
node2 = Arithmetic(A=node1, B=Arange(), eqn="A+B")

with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 8
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

omt = node2.eval(coords)

with podpac.settings:
podpac.settings["MULTITHREADING"] = False
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

ost = node2.eval(coords)

np.testing.assert_array_equal(omt, ost)

def test_multi_threading_cache_race(self):
coords = podpac.Coordinates([np.linspace(0, 1, 1024)], ["lat"])
with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 3
podpac.settings["CACHE_OUTPUT_DEFAULT"] = True
podpac.settings["DEFAULT_CACHE"] = ["ram"]
podpac.settings["RAM_CACHE_ENABLED"] = True
podpac.settings.set_unsafe_eval(True)
A = Arithmetic(A=Arange(), eqn="A**2")
B = Arithmetic(A=Arange(), eqn="A**2")
C = Arithmetic(A=Arange(), eqn="A**2")
D = Arithmetic(A=Arange(), eqn="A**2")
E = Arithmetic(A=Arange(), eqn="A**2")
F = Arithmetic(A=Arange(), eqn="A**2")

node2 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, eqn="A+B+C+D+E+F")

om = node2.eval(coords)

from_cache = [n._from_cache for n in node2.inputs.values()]

assert sum(from_cache) > 0

def test_multi_threading_stress_nthreads(self):
coords = podpac.Coordinates([np.linspace(0, 1, 4)], ["lat"])

A = Arithmetic(A=Arange(), eqn="A**0")
B = Arithmetic(A=Arange(), eqn="A**1")
C = Arithmetic(A=Arange(), eqn="A**2")
D = Arithmetic(A=Arange(), eqn="A**3")
E = Arithmetic(A=Arange(), eqn="A**4")
F = Arithmetic(A=Arange(), eqn="A**5")

node2 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, eqn="A+B+C+D+E+F")
node3 = Arithmetic(A=A, B=B, C=C, D=D, E=E, F=F, G=node2, eqn="A+B+C+D+E+F+G")

with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 8
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

omt = node3.eval(coords)

assert node3._multi_threaded
assert not node2._multi_threaded

with podpac.settings:
podpac.settings["MULTITHREADING"] = True
podpac.settings["N_THREADS"] = 9 # 2 threads available after first 7
podpac.settings["CACHE_OUTPUT_DEFAULT"] = False
podpac.settings["DEFAULT_CACHE"] = []
podpac.settings["RAM_CACHE_ENABLED"] = False
podpac.settings.set_unsafe_eval(True)

omt = node3.eval(coords)

assert node3._multi_threaded
assert node2._multi_threaded
24 changes: 21 additions & 3 deletions podpac/core/compositor.py
@@ -5,7 +5,6 @@

from __future__ import division, unicode_literals, print_function, absolute_import

from multiprocessing.pool import ThreadPool
import numpy as np
import traitlets as tl

@@ -20,6 +19,7 @@
from podpac.core.data.datasource import COMMON_DATA_DOC
from podpac.core.data.interpolation import interpolation_trait
from podpac.core.utils import trait_is_defined
from podpac.core.managers.multi_threading import thread_manager

COMMON_COMPOSITOR_DOC = COMMON_DATA_DOC.copy() # superset of COMMON_NODE_DOC

@@ -219,27 +219,45 @@ def iteroutputs(self, coordinates):
nc = merge_dims([Coordinates(np.atleast_1d(c), dims=[coords_dim]), self.shared_coordinates])

if trait_is_defined(s, "native_coordinates") is False:
s.set_trait('native_coordinates', nc)
s.set_trait("native_coordinates", nc)

if settings["MULTITHREADING"]:
n_threads = thread_manager.request_n_threads(len(src_subset))
if n_threads == 1:
thread_manager.release_n_threads(n_threads)
else:
n_threads = 0

if settings["MULTITHREADING"] and n_threads > 1:
# TODO pool of pre-allocated scratch space
# TODO: docstring?
def f(src):
return src.eval(coordinates)

pool = ThreadPool(processes=settings.get("N_THREADS", 10))
# Create a pool of size n_threads. Note that this pool may be created from a sub-thread (i.e. not the main thread)
pool = thread_manager.get_thread_pool(processes=n_threads)

# Evaluate nodes in parallel/asynchronously
results = [pool.apply_async(f, [src]) for src in src_subset]

# Yield results as they are requested, blocking until each source's thread has finished
for src, res in zip(src_subset, results):
yield res.get()
# src._output = None # free up memory

# This prevents any more tasks from being submitted to the pool and closes the workers once they are done
pool.close()

# Release this number of threads back to the thread manager's budget
thread_manager.release_n_threads(n_threads)
self._multi_threaded = True
else:
output = None # scratch space
for src in src_subset:
output = src.eval(coordinates, output)
yield output
# output[:] = np.nan
self._multi_threaded = False

@node_eval
@common_doc(COMMON_COMPOSITOR_DOC)
86 changes: 86 additions & 0 deletions podpac/core/managers/multi_threading.py
@@ -0,0 +1,86 @@
"""
Module for dealing with multi-threaded execution.

This is used to ensure that the total number of threads specified in the settings is not exceeded.

"""

from __future__ import division, unicode_literals, print_function, absolute_import

from multiprocessing import Lock
from multiprocessing.pool import ThreadPool

from podpac.core.settings import settings

DEFAULT_N_THREADS = 10


class ThreadManager(object):
""" This is a singleton class that keeps track of the total number of threads used in an application.
"""

_lock = Lock()
cache_lock = Lock()
_n_threads_used = 0
__instance = None

def __new__(cls):
if ThreadManager.__instance is None:
ThreadManager.__instance = object.__new__(cls)
return ThreadManager.__instance

def request_n_threads(self, n):
""" Returns the number of threads allowed for a pool taking into account all other threads application, as
specified by podpac.settings["N_THREADS"].

Parameters
-----------
n : int
Number of threads requested by operation

Returns
--------
int
Number of threads a pool may use. Note that this may be less than or equal to n, and may be 0.
"""
with self._lock:
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
claimed = min(available, n)
self._n_threads_used += claimed
return claimed

def release_n_threads(self, n):
""" This releases the number of threads specified.

Parameters
------------
n : int
Number of threads to be released

Returns
--------
int
Number of threads available after releasing 'n' threads
"""
with self._lock:
self._n_threads_used = max(0, self._n_threads_used - n)
available = max(0, settings.get("N_THREADS", DEFAULT_N_THREADS) - self._n_threads_used)
return available

def get_thread_pool(self, processes):
""" Creates a threadpool that can be used to run jobs in parallel.

Parameters
-----------
processes : int
The number of threads or workers that will be part of the pool

Returns
--------
multiprocessing.pool.ThreadPool
An instance of the ThreadPool class
"""
return ThreadPool(processes=processes)


thread_manager = ThreadManager()
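
The manager enforces a single global thread budget (podpac.settings["N_THREADS"], defaulting to 10) across every node that is evaluating concurrently. A minimal sketch of the claim/use/release contract that Algorithm.eval and Compositor.iteroutputs follow (the numbers are illustrative):

```python
from podpac.core.settings import settings
from podpac.core.managers.multi_threading import thread_manager

settings["N_THREADS"] = 8

n = thread_manager.request_n_threads(6)  # claims min(available, requested) -> 6
m = thread_manager.request_n_threads(6)  # only 2 remain in the budget -> 2

pool = thread_manager.get_thread_pool(processes=n)
# ... pool.apply_async(...) per input, then collect each result with res.get() ...
pool.close()

# Return the claimed threads to the budget once the work is finished
thread_manager.release_n_threads(n)
thread_manager.release_n_threads(m)
```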
11 changes: 8 additions & 3 deletions podpac/core/node.py
@@ -25,6 +25,7 @@
from podpac.core.coordinates import Coordinates
from podpac.core.style import Style
from podpac.core.cache import CacheCtrl, get_default_cache_ctrl, S3CacheStore, make_cache_ctrl
from podpac.core.managers.multi_threading import thread_manager


COMMON_NODE_DOC = {
@@ -136,6 +137,8 @@ def _validate_units(self, d):
_requested_coordinates = tl.Instance(Coordinates, allow_none=True)
_output = tl.Instance(UnitsDataArray, allow_none=True)
_from_cache = tl.Bool(allow_none=True, default_value=None)
# Flag that is True if the Node was run multi-threaded, or None if the question doesn't apply
_multi_threaded = tl.Bool(allow_none=True, default_value=None)

def __init__(self, **kwargs):
""" Do not overwrite me """
@@ -478,11 +481,11 @@ def put_cache(self, data, key, coordinates=None, overwrite=False):
NodeException
Cached data already exists (and overwrite is False)
"""

if not overwrite and self.has_cache(key, coordinates=coordinates):
raise NodeException("Cached data already exists for key '%s' and coordinates %s" % (key, coordinates))

self.cache_ctrl.put(self, data, key, coordinates=coordinates, update=overwrite)
with thread_manager.cache_lock:
self.cache_ctrl.put(self, data, key, coordinates=coordinates, update=overwrite)

def has_cache(self, key, coordinates=None):
"""
@@ -500,7 +503,8 @@ def has_cache(self, key, coordinates=None):
bool
True if there is cached data for this node, key, and coordinates.
"""
return self.cache_ctrl.has(self, key, coordinates=coordinates)
with thread_manager.cache_lock:
return self.cache_ctrl.has(self, key, coordinates=coordinates)

def rem_cache(self, key, coordinates=None, mode=None):
"""
@@ -807,6 +811,7 @@ def wrapper(self, coordinates, output=None):
self._requested_coordinates = coordinates
key = cache_key
cache_coordinates = coordinates.transpose(*sorted(coordinates.dims)) # order agnostic caching

if self.has_cache(key, cache_coordinates) and not self.cache_update:
data = self.get_cache(key, cache_coordinates)
if output is not None:
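put_cache and has_cache now serialize access to the cache controller through the shared thread_manager.cache_lock, so input nodes evaluated on different worker threads do not read and write the cache store concurrently (the new test_multi_threading_cache_race test exercises this path). Other code that touches the cache from a worker thread can reuse the same lock; a small sketch, where the helper name is hypothetical and not part of this PR:

```python
from podpac.core.managers.multi_threading import thread_manager

def locked_get_cache(node, key, coordinates=None):
    """Hypothetical helper: read a cached output without racing a concurrent put."""
    with thread_manager.cache_lock:
        return node.get_cache(key, coordinates)
```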
6 changes: 3 additions & 3 deletions podpac/datalib/gfs.py
@@ -109,7 +109,7 @@ def init(self):
base_time = datetime.datetime.strptime("%s %s" % (self.date, self.hour), "%Y%m%d %H%M")
forecast_times = [base_time + datetime.timedelta(hours=int(h)) for h in self.forecasts]
tc = Coordinates([[dt.strftime("%Y-%m-%d %H:%M") for dt in forecast_times]], dims=["time"])
self.set_trait('native_coordinates', merge_dims([nc, tc]))
self.set_trait("native_coordinates", merge_dims([nc, tc]))

def get_data(self, coordinates, coordinates_index):
data = self.create_output_array(coordinates)
@@ -125,13 +125,13 @@ def init(self):
now = datetime.datetime.now()

# date
self.set_trait('date', now.strftime("%Y%m%d"))
self.set_trait("date", now.strftime("%Y%m%d"))

# hour
prefix = "%s/%s/%s/" % (self.parameter, self.level, self.date)
objs = bucket.objects.filter(Prefix=prefix)
hours = set(obj.key.split("/")[3] for obj in objs)
if hours:
self.set_trait('hour', max(hours))
self.set_trait("hour", max(hours))

super(GFSLatest, self).init()