diff --git a/Makefile b/Makefile index bdebed0b5ae6..66fc540ea185 100644 --- a/Makefile +++ b/Makefile @@ -61,10 +61,9 @@ ifneq ($(ADD_LDFLAGS), NONE) LDFLAGS += $(ADD_LDFLAGS) endif -#BIN = test/test_threaded_engine test/api_registry_test +BIN = tests/test_simple_engine OBJ = narray_function_cpu.o -# add threaded engine after it is done -OBJCXX11 = reshape_cpu.o engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o +OBJCXX11 = reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a @@ -82,7 +81,8 @@ $(DMLC_CORE)/libdmlc.a: + cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR) storage.o: src/storage/storage.cc -engine.o: src/dag_engine/simple_engine.cc +dag_engine.o: src/dag_engine/dag_engine.cc +simple_engine.o: src/dag_engine/simple_engine.cc narray.o: src/narray/narray.cc narray_function_cpu.o: src/narray/narray_function.cc src/narray/narray_function-inl.h narray_function_gpu.o: src/narray/narray_function.cu src/narray/narray_function-inl.h @@ -111,7 +111,8 @@ iter_mnist.o: src/io/iter_mnist.cc lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) $(LIB_DEP) lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) $(LIB_DEP) -test/test_storage: test/test_storage.cc lib/libmxnet.a +tests/test_storage: tests/test_storage.cc lib/libmxnet.a +tests/test_simple_engine: tests/test_simple_engine.cc lib/libmxnet.a $(BIN) : $(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) @@ -125,7 +126,7 @@ $(OBJCXX11) : $(SLIB) : $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) -$(ALIB): +$(ALIB): $(OBJ) $(OBJCXX11) ar cr $@ $+ 
$(CUOBJ) : diff --git a/doc/Doxyfile b/doc/Doxyfile index 407ea96e95ab..aeef012f2384 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -1925,7 +1925,7 @@ PERLMOD_MAKEVAR_PREFIX = # C-preprocessor directives found in the sources and include files. # The default value is: YES. -ENABLE_PREPROCESSING = NO +ENABLE_PREPROCESSING = YES # If the MACRO_EXPANSION tag is set to YES doxygen will expand all macro names # in the source code. If set to NO only conditional compilation will be diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h index 18b804b5a2d8..f2bf063a1bb1 100644 --- a/include/mxnet/dag_engine.h +++ b/include/mxnet/dag_engine.h @@ -1,107 +1,162 @@ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2015 by Contributors * \file dag_engine.h - * \brief dynamic data-flow dag engine that schedules - * operations in a concurrent way + * \brief DAG engine that schedules data. */ #ifndef MXNET_DAG_ENGINE_H_ #define MXNET_DAG_ENGINE_H_ #include -// check c++11 + #if DMLC_USE_CXX11 == 0 -#error "cxx11 was required for narray module" +#error "C++11 was required for DAG engine module." #endif #include #include -#include "./base.h" -#include "./context.h" +#include "base.h" +#include "context.h" namespace mxnet { + +/*! + * \brief Namespace of engine implementation. + */ +namespace engine { + /*! - * \brief dynamic data-flow dag engine that schedules - * operations in a concurrent way + * \brief Inner representation of variable. + */ +struct Var; + +/*! + * \brief Inner representation of operator. + */ +struct Opr; + +} // namespace engine + +/*! + * \brief Dynamic dataflow DAG engine that schedules operations. */ class DAGEngine { public: /*! - * \brief operation to pass to DAG engine - * \param ctx runtime context + * \brief Operation to pass to DAG engine. + */ + using Fn = std::function; + /*! + * \brief Callback function to notify operation complete. + */ + using Callback = std::function; + /*! 
+ * \brief Asynchronous operation to pass to DAG engine. + */ + using AsyncFn = std::function; + /*! + * \brief Variable of dag engine, used to specify dependencies defined to be a + * pointer, that points to an internal data structure of the engine + * itself. + */ + using Variable = engine::Var*; + /*! + * \brief Operator of the engine. + */ + using OprHandle = engine::Opr*; + /*! + * \brief Allocate a new variable, the variable can then + * be used to schedule the operation concurrently via dependency + * patterns. + * \return The new variable allocated. + */ + virtual Variable NewVar() = 0; + /*! + * \brief Create a new operator. The returned operator could be saved + * externally so that it could be reused for scheduling. + * \param fn The execution function. + * \param use_vars The variables that current operation will use but not + * mutate. + * \param mutate_vars The variables that current operation will mutate. + * \return The new operator allocated. + */ + virtual OprHandle NewOperator(AsyncFn fn, + std::vector const& use_vars, + std::vector const& mutate_vars) = 0; + /*! + * \brief Delete the given operator. + * \param op The operator to delete. */ - typedef std::function Op; - /*! \brief callback function to notify operation complete */ - typedef std::function Callback; + virtual void DeleteOperator(OprHandle op) = 0; /*! - * \brief operation to pass to DAG engine - * \param ctx runtime context - * \param on_complete a callback function used to notify the engine the action completes + * \brief Push an operator to the engine. + * \param op The operator to push. + * \param exec_ctx Execution context. */ - typedef std::function AsyncOp; + virtual void Push(OprHandle op, Context exec_ctx) = 0; /*! - * \brief variable of dag engine, used to specify dependencies - * defined to be a pointer, that can points to a internal data structure - * of the engine itself + * \brief Push a synchronous operation to the DAG engine. 
+ * \param exec_fun Execution function that executes the operation. + * \param exec_ctx Execution context. + * \param use_vars The variables that current operation will use but not + * mutate. + * \param mutate_vars The variables that current operation will mutate. + */ + void Push(Fn exec_fun, Context exec_ctx, + std::vector const& use_vars, + std::vector const& mutate_vars); + /*! + * \brief Push an asynchronous operation to the DAG engine. + * \param exec_fun Execution function, this function takes a parameter + * on_complete that must be called when the execution + * completes. + * \param exec_ctx Execution context. + * \param use_vars The variables that current operation will use but not + * mutate. + * \param mutate_vars The variables that current operation will mutate. + */ + virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, + std::vector const& use_vars, + std::vector const& mutate_vars) = 0; + /*! + * \brief Schedule the delete of a variable. * - * Design detail: we choose pointer instead of some ID to avoid - * indirect map lookup. usually, Variable directly points to the content we need - */ - typedef void *Variable; - /*! - * \brief Push an asynchronize operation to the DAG engine - * \param exec_fun execution funtion, this function takes a parameter on_complete - * that must be called when the execution completes. For synchronize operations - * \param exec_ctx execution context - * \param use_vars the variables that current operation will use(but not mutate) - * \param mutate_vars the variables that current operation will mutate - */ - virtual void PushAsync(AsyncOp exec_fun, - Context exec_ctx, - const std::vector &use_vars, - const std::vector &mutate_vars) = 0; - /*! 
- * \brief Push an synchronize operation to the DAG engine - * \param exec_fun execution funtion that executes the operation - * \param exec_ctx execution context - * \param use_vars the variables that current operation will use(but not mutate) - * \param mutate_vars the variables that current operation will mutate - */ - virtual void Push(Op exec_fun, - Context exec_ctx, - const std::vector &use_vars, - const std::vector &mutate_vars) { - AsyncOp f = [exec_fun](RunContext ctx, Callback on_complete) { - exec_fun(ctx); on_complete(); - }; - this->PushAsync(f, exec_ctx, use_vars, mutate_vars); - } - /*! - * \brief schedule the delete of variable var, - * The delete will not happen immediately, but will wait until all the operations - * depending on var is completed + * The delete will not happen immediately, but will wait until all the + * operations depending on var are completed. * - * \param delete_fun a function that will be called after var is deleted - * \param exec_ctx execution context - * \param var the variable to be deleted + * \param delete_fun A function that will be called after the variable is + * deleted. + * \param exec_ctx Execution context. + * \param var The variable to be deleted. */ - virtual void PushDelete(Op delete_fun, - Context exec_ctx, - Variable var) = 0; + virtual void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) = 0; /*! - * \brief allocate a new variable, the variable can then - * be used to schedul the operation concurrently via dependency patterns - * \return thew new variable allocated + * \brief Wait for variable. + * \param var The variable we should wait for, this function returns when all + * the operations related to var have been completed. */ - virtual Variable NewVar() = 0; + virtual void WaitForVar(Variable var) = 0; /*! 
- * \brief wait for variable var - * \param var the variable we should wait for, this function returns when all the operations - * related to var has been completed + * \brief Wait until all the activity of dag engine finishes. */ - virtual void WaitForVar(Variable var) {} - /*! \brief wait until all the activity of dag engine finishes */ - virtual void WaitForAll() {} - /*! \return DAG engine singleton */ - static DAGEngine *Get(); -}; + virtual void WaitForAll() = 0; + /*! + * \brief Virtual destructor. + */ + virtual ~DAGEngine() noexcept(false); + /*! + * \return DAG engine singleton. + */ + static DAGEngine* Get(); + + protected: + /*! + * \brief Hidden constructors. + */ + DAGEngine(); + + private: + DISALLOW_COPY_AND_ASSIGN(DAGEngine); +}; // class DAGEngine + } // namespace mxnet + #endif // MXNET_DAG_ENGINE_H_ diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h deleted file mode 100644 index 82e2598816a5..000000000000 --- a/src/common/concurrent_blocking_queue.h +++ /dev/null @@ -1,120 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file concurrent_blocking_queue.h - * \brief A simple lock-based consumer-producer queue. - */ -#ifndef MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ -#define MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ - -#include -#include -#include -#include -#include -#include - -/*! - * \brief Common components. - */ -namespace common { - -/*! - * \brief A simple lock-based consumer-producer queue. - */ -template class ConcurrentBlockingQueue { - static const int kBusyLoop = 1000; - - public: - ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) { - } - /*! - * \brief Push object into the queue. Notify anyone who is waiting. - * \param e the object - */ - void Push(const T& e) { - std::lock_guard lock(mutex_); - has_elmt_ = true; - queue_.push_back(e); - if (queue_.size() == 1) { - cv_.notify_all(); - } - } - /*! - * \brief Pop object out of the queue. 
If the queue is empty, the caller thread will sleep until - * (1) Producer pushed some product into the queue and the caller thread wins it. - * (2) A kill signal is passed to the queue. - * \param rv the pointer point to the return object - * \return whether an object is returned - */ - bool Pop(T* rv) { - for (int i = 0; i < kBusyLoop; i++) { - if (has_elmt_) { - std::lock_guard lock(mutex_); - if (!has_elmt_) { - assert(queue_.empty()); - continue; - } - *rv = queue_.front(); - queue_.pop_front(); - if (queue_.empty()) - has_elmt_ = false; - return false; - } - } - { - std::unique_lock lock(mutex_); - while (queue_.empty() && !exit_now_) { - cv_.wait(lock); - } - if (!exit_now_) { - *rv = queue_.front(); - queue_.pop_front(); - if (queue_.empty()) - has_elmt_ = false; - return false; - } else { - return true; - } - } - } - /*! - * \brief pop all objects in the queue. - * \return a list containing all objects in the queue. - */ - std::list PopAll() { - std::lock_guard lock(mutex_); - std::list rv; - rv.swap(queue_); - return rv; - } - /*! - * \brief tell the queue to release all waiting consumers - */ - void SignalForKill() { - std::unique_lock lock(mutex_); - exit_now_ = true; - cv_.notify_all(); - } - /*! - * \brief return the current queue size - * \return queue size - */ - size_t QueueSize() { - std::unique_lock lock(mutex_); - return queue_.size(); - } - - private: - std::atomic has_elmt_; - std::list queue_; - std::mutex mutex_; - std::condition_variable cv_; - std::atomic exit_now_; - - ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete; - ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete; -}; - -} // namespace common - -#endif // MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ diff --git a/src/common/spin_lock.h b/src/common/spin_lock.h deleted file mode 100644 index 60850f171ecf..000000000000 --- a/src/common/spin_lock.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2015 by Contributors - * Spin lock using xchg. 
- * Copied from http://locklessinc.com/articles/locks/ - */ - -#ifndef MXNET_COMMON_SPIN_LOCK_H_ -#define MXNET_COMMON_SPIN_LOCK_H_ - -/* Compile read-write barrier */ -#define barrier() asm volatile("": : :"memory") - -/* Pause instruction to prevent excess processor bus usage */ -#define cpu_relax() asm volatile("pause\n": : :"memory") - -static inline unsigned short xchg_8(void *ptr, unsigned char x) { // NOLINT(*) - __asm__ __volatile__("xchgb %0,%1" - :"=r" (x) - :"m" (*(volatile unsigned char *)ptr), "0" (x) - :"memory"); - - return x; -} - -#define BUSY 1 -typedef unsigned char spinlock; - -/*! - * \brief use this value to initialize lock object - */ -#define SPINLOCK_INITIALIZER 0 - -/*! - * \brief lock - * \param lock the pointer to lock object - */ -static inline void spin_lock(spinlock *lock) { - while (1) { - if (!xchg_8(lock, BUSY)) return; - - while (*lock) cpu_relax(); - } -} - -/*! - * \brief unlock - * \param lock the pointer to lock object - */ -static inline void spin_unlock(spinlock *lock) { - barrier(); - *lock = 0; -} - -/*! - * \brief try lock - * \param lock the pointer to lock object - * \return whether the lock is grabbed or not - */ -static inline int spin_trylock(spinlock *lock) { - return xchg_8(lock, BUSY); -} - -#endif // MXNET_COMMON_SPIN_LOCK_H_ diff --git a/src/dag_engine/dag_engine.cc b/src/dag_engine/dag_engine.cc new file mode 100644 index 000000000000..773971c5db6b --- /dev/null +++ b/src/dag_engine/dag_engine.cc @@ -0,0 +1,34 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + */ +#include "mxnet/dag_engine.h" +#include "simple_engine.h" +#include "dag_engine_impl.h" + +namespace mxnet { + +void DAGEngine::Push(Fn exec_fun, Context exec_ctx, + std::vector const& use_vars, + std::vector const& mutate_vars) { + auto f = [exec_fun](RunContext ctx, Callback on_complete) { + exec_fun(ctx); + on_complete(); + }; + PushAsync(f, exec_ctx, use_vars, mutate_vars); +} + +DAGEngine::~DAGEngine() noexcept(false) {} + +DAGEngine::DAGEngine() = default; + +DAGEngine* DAGEngine::Get() { + /*! + * \brief Change specific engine to use. + */ + using EngineImplementation = engine::SimpleEngine; + + static EngineImplementation inst; + return &inst; +} + +} // namespace mxnet diff --git a/src/dag_engine/dag_engine_impl.h b/src/dag_engine/dag_engine_impl.h new file mode 100644 index 000000000000..aa090074710b --- /dev/null +++ b/src/dag_engine/dag_engine_impl.h @@ -0,0 +1,56 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_DAG_ENGINE_DAG_ENGINE_IMPL_H_ +#define MXNET_DAG_ENGINE_DAG_ENGINE_IMPL_H_ + +#include +#include "mxnet/dag_engine.h" + +// #define DAG_ENGINE_DEBUG + +namespace mxnet { +namespace engine { + +struct Var { +#ifdef DAG_ENGINE_DEBUG + virtual ~Var() = default; +#endif // DAG_ENGINE_DEBUG + template + T* Cast(); +}; // struct Var + +struct Opr { +#ifdef DAG_ENGINE_DEBUG + virtual ~Opr() = default; +#endif // DAG_ENGINE_DEBUG + template + T* Cast(); +}; // struct Opr + +template +T* Var::Cast() { + static_assert(std::is_base_of::value, + "must inherit `mxnet::engine::Var`"); +#ifndef DAG_ENGINE_DEBUG + return static_cast(this); +#else // DAG_ENGINE_DEBUG + return dynamic_cast(this); +#endif // DAG_ENGINE_DEBUG +} + +template +T* Opr::Cast() { + static_assert(std::is_base_of::value, + "must inherit `mxnet::engine::Opr`"); +#ifndef DAG_ENGINE_DEBUG + return static_cast(this); +#else // DAG_ENGINE_DEBUG + return dynamic_cast(this); +#endif // DAG_ENGINE_DEBUG +} + +} // namespace engine +} 
// namespace mxnet + +#endif // MXNET_DAG_ENGINE_DAG_ENGINE_IMPL_H_ diff --git a/src/dag_engine/object_pool.h b/src/dag_engine/object_pool.h new file mode 100644 index 000000000000..1257cb540ccc --- /dev/null +++ b/src/dag_engine/object_pool.h @@ -0,0 +1,108 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_DAG_ENGINE_OBJECT_POOL_H_ +#define MXNET_DAG_ENGINE_OBJECT_POOL_H_ +#include +#include +#include "common.h" + +template +class SmallObjectPool { + public: + struct LinkedList { + union { + LinkedList* next{nullptr}; + T t; + }; + }; + ~SmallObjectPool() = default; + T* New(); + void Delete(T* ptr); + + static SmallObjectPool* Get(); + + private: + constexpr static std::size_t kPageSize = 1 << 12; + std::recursive_mutex m_; + LinkedList* head_{nullptr}; + SmallObjectPool(); + void AllocateChunk(); + + SmallObjectPool(SmallObjectPool const&) = delete; + SmallObjectPool(SmallObjectPool&&) = delete; + SmallObjectPool& operator=(SmallObjectPool const&) = delete; + SmallObjectPool& operator=(SmallObjectPool&&) = delete; +}; + +template +T* SmallObjectPool::New() { + LinkedList* ret; + { + std::lock_guard lock{m_}; + if (head_->next == nullptr) { + AllocateChunk(); + } + ret = head_; + head_ = head_->next; + } + return new(static_cast(ret)) T{}; +} + +template +void SmallObjectPool::Delete(T* ptr) { + ptr->~T(); + auto linked_list_ptr = reinterpret_cast(ptr); + { + std::lock_guard lock{m_}; + linked_list_ptr->next = head_; + head_ = linked_list_ptr; + } +} + +template +SmallObjectPool* SmallObjectPool::Get() { + static SmallObjectPool inst; + return &inst; +} + +template +SmallObjectPool::SmallObjectPool() { + AllocateChunk(); +} + +template +void SmallObjectPool::AllocateChunk() { + std::lock_guard lock{m_}; + static_assert(kPageSize % sizeof(LinkedList) == 0, + "Could not align to page size."); + auto&& new_chunk = static_cast(malloc(kPageSize)); + auto size = kPageSize / sizeof(LinkedList); + for (std::size_t i = 0 ; i < size - 1; ++i) { + 
new_chunk[i].next = &new_chunk[i + 1]; + } + new_chunk[size - 1].next = head_; + head_ = new_chunk; +} + + +struct A { + A() { + LOG("constructing"); + } + ~A() { + LOG("destructing"); + } +}; + +int main() { + auto&& pool = SmallObjectPool::Get(); + auto a = pool->New(); + auto b = pool->New(); + LOG("addresses %p %p", a, b); + pool->Delete(a); + a = pool->New(); + LOG("address again %p", a); + return 0; +} +#endif // MXNET_DAG_ENGINE_OBJECT_POOL_H_ diff --git a/src/dag_engine/simple_engine.cc b/src/dag_engine/simple_engine.cc index d2ec5cfd3c7e..36eea89fd849 100644 --- a/src/dag_engine/simple_engine.cc +++ b/src/dag_engine/simple_engine.cc @@ -1,52 +1,277 @@ -// Copyright (c) 2015 by Contributors -#include +/*! + * Copyright (c) 2015 by Contributors + */ +#include "simple_engine.h" #include -#include +#include +#include +#include +#include +#include +#include "../common/cuda_utils.h" + namespace mxnet { -class SimpleEngine : public DAGEngine { - public: - virtual void PushAsync(AsyncOp exec_fun, - Context exec_ctx, - const std::vector &use_vars, - const std::vector &mutate_vars) { - // cannot schedule async using naive way because deps are not captured - LOG(FATAL) << "cannot schedule async operations"; - } - virtual void Push(Op exec_fun, - Context exec_ctx, - const std::vector &use_vars, - const std::vector &mutate_vars) { - if (exec_ctx.dev_mask == gpu::kDevMask) { -#if MXNET_USE_CUDA - ctx_.stream = &stream; - mshadow::SetDevice(exec_ctx.dev_id); - exec_fun(ctx_); -#else - LOG(FATAL) << "GPU is not enabled"; -#endif + +namespace engine { + +#ifdef DAG_ENGINE_DEBUG +std::atomic OprBlock::counter{0}; +std::atomic VersionedVarBlock::counter{0}; +std::atomic SimpleVar::counter{0}; +std::atomic SimpleOpr::counter{0}; +#endif // DAG_ENGINE_DEBUG + +SimpleVar* SimpleVar::CastFromBase(Var* v) { return v->Cast(); } + +SimpleOpr* SimpleOpr::CastFromBase(Opr* o) { return o->Cast(); } + +SimpleEngine::SimpleEngine() + : pending_{0}, thread_pool_{[this]() { 
ThreadWorker(); }} {} + +SimpleEngine::~SimpleEngine() noexcept(false) { task_queue_.SignalForKill(); } + +SimpleVar* SimpleEngine::NewVar() { + auto ret = new SimpleVar{}; + ret->head = new VersionedVarBlock{}; + return ret; +} + +SimpleOpr* SimpleEngine::NewOperator(SimpleEngine::AsyncFn fn, + std::vector const& use_vars, + std::vector const& mutate_vars) { + auto ret = new SimpleOpr{}; + ret->fn = fn; + ret->use_vars.resize(use_vars.size()); + ret->mutate_vars.resize(mutate_vars.size()); + std::transform(use_vars.begin(), use_vars.end(), ret->use_vars.begin(), + SimpleVar::CastFromBase); + std::transform(mutate_vars.begin(), mutate_vars.end(), + ret->mutate_vars.begin(), SimpleVar::CastFromBase); +#ifdef DAG_ENGINE_DEBUG + // Check for duplicates. + auto use = use_vars; + auto mutate = mutate_vars; + auto use_size = use.size(); + auto mutate_size = mutate.size(); + std::sort(use.begin(), use.end()); + std::sort(mutate.begin(), mutate.end()); + for (std::size_t i = 0; i < use_size; ++i) { + if (i != 0 && use.at(i) == use.at(i - 1)) { + LOG(FATAL) << "duplicate items found in `use_vars`"; + } + } + for (std::size_t i = 0; i < mutate_size; ++i) { + if (i != 0 && mutate.at(i) == mutate.at(i - 1)) { + LOG(FATAL) << "duplicate items found in `mutate_vars`"; + } + } + std::size_t j = 0; + for (std::size_t i = 0; i < use_size; ++i) { + while (j < mutate_size && mutate.at(j) < use.at(i)) { + ++j; + } + if (j == mutate_size) { + break; + } + if (mutate.at(j) == use.at(i)) { + LOG(FATAL) + << "duplicate items found between `use_vars` and `mutate_vars`"; + } + } +#endif // DAG_ENGINE_DEBUG + return ret; +} + +void SimpleEngine::DeleteOperator(OprHandle op) { + auto&& simple_opr = SimpleOpr::CastFromBase(op); + std::vector deps{}; + deps.reserve(simple_opr->use_vars.size() + simple_opr->mutate_vars.size()); + deps.insert(deps.end(), simple_opr->use_vars.begin(), + simple_opr->use_vars.end()); + deps.insert(deps.end(), simple_opr->mutate_vars.begin(), + 
simple_opr->mutate_vars.end()); + auto&& func = [simple_opr](RunContext) { delete simple_opr; }; + Push(func, Context{}, {}, deps); +} + +void SimpleEngine::Push(OprHandle op, Context exec_ctx) { + auto&& simple_opr = SimpleOpr::CastFromBase(op); + auto&& opr_block = new OprBlock{}; + opr_block->opr = simple_opr; + opr_block->wait.store(simple_opr->use_vars.size() + + simple_opr->mutate_vars.size() + 1); + opr_block->ctx = exec_ctx; + opr_block->rctx = RunContext{nullptr}; + ++pending_; + // Add read dependencies. + for (auto&& i : simple_opr->use_vars) { + std::lock_guard lock{i->m}; + if (i->ready_to_read) { + assert(i->pending_write == nullptr); + ++i->num_pending_reads; + --opr_block->wait; } else { - exec_fun(ctx_); + auto&& new_var_block = new VersionedVarBlock{}; + assert(i->head->next == nullptr); + assert(i->head->trigger == nullptr); + assert(i->head->write == false); + i->head->next = new_var_block; + i->head->trigger = opr_block; + i->head = new_var_block; } } - virtual void PushDelete(Op delete_fun, - Context exec_ctx, - Variable var) { - this->Push(delete_fun, exec_ctx, {}, {var}); + // Add write dependencies. + for (auto&& i : simple_opr->mutate_vars) { + std::lock_guard lock{i->m}; + auto&& new_var_block = new VersionedVarBlock{}; + i->head->next = new_var_block; + i->head->trigger = opr_block; + i->head->write = true; + if (i->ready_to_read) { + /*! + * Raise `num_pending_reads` temporarily to avoid premature triggering. 
+ */ + ++i->num_pending_reads; + i->pending_write = i->head; + if (--i->num_pending_reads == 0) { + --opr_block->wait; + } + i->ready_to_read = false; + } + i->head = new_var_block; } - virtual Variable NewVar() { - // in practice return a ptr to a cell - // that have the info about the variable - // use ptr directly instead of ID because this avoids an indirect mapping - return NULL; + if (--opr_block->wait == 0) { + task_queue_.Push(opr_block); } +} + +void SimpleEngine::PushAsync(AsyncFn fn, Context exec_ctx, + std::vector const& use_vars, + std::vector const& mutate_vars) { + auto&& opr = NewOperator(fn, use_vars, mutate_vars); + opr->temporary = true; + Push(opr, exec_ctx); +} + +void SimpleEngine::PushDelete(Fn delete_fn, Context exec_ctx, Variable var) { + auto&& simple_var = SimpleVar::CastFromBase(var); + auto&& func = [delete_fn, simple_var](RunContext ctx) { + /*! + * Mark variable as orphan, so during `SimpleEngine::OnComplete` it could be + * recycled. + */ + simple_var->to_delete = true; + delete_fn(ctx); + }; + Push(func, exec_ctx, {}, {var}); +} + +void SimpleEngine::WaitForVar(Variable var) { + std::unique_lock lock{finished_m_}; + std::atomic done{false}; + auto&& callback = [this, &done](RunContext) { + std::unique_lock lock{finished_m_}; + done.store(true); + finished_cv_.notify_all(); + }; + Push(callback, Context{}, {var}, {}); + finished_cv_.wait(lock, [&done]() { return done.load(); }); +} - private: - RunContext ctx_; - mshadow::Stream stream; -}; -// implements the singleton factory -DAGEngine* DAGEngine::Get() { - static SimpleEngine engine; - return &engine; +void SimpleEngine::WaitForAll() { + std::unique_lock lock{finished_m_}; + finished_cv_.wait(lock, [this]() { return pending_.load() == 0; }); } + +void SimpleEngine::OnComplete(SimpleOpr* simple_opr) { + /*! + * Mark complete for read variables. 
+ */ + for (auto&& i : simple_opr->use_vars) { + std::lock_guard lock{i->m}; + if (--i->num_pending_reads == 0) { + if (i->pending_write != nullptr && + --i->pending_write->trigger->wait == 0) { + task_queue_.Push(i->pending_write->trigger); + } + } + } + /*! + * Mark complete for write variables. + */ + for (auto&& i : simple_opr->mutate_vars) { + bool to_delete = false; + { + std::lock_guard lock{i->m}; + assert(i->ready_to_read == false); + auto head = i->pending_write->next; + delete i->pending_write; + i->pending_write = nullptr; + if (i->to_delete) { + assert(head->next == nullptr); + delete head; + to_delete = true; + } else { + while (true) { + if (head->write == true) { + ++i->num_pending_reads; + i->pending_write = head; + if (--i->num_pending_reads == 0) { + if (--head->trigger->wait == 0) { + task_queue_.Push(head->trigger); + } + } + break; + } else if (head->next == nullptr) { + i->ready_to_read = true; + break; + } else { + ++i->num_pending_reads; + if (--head->trigger->wait == 0) { + task_queue_.Push(head->trigger); + } + auto prev = head; + head = head->next; + delete prev; + } + } + } + } + if (to_delete) { + delete i; + } + } + { + std::unique_lock lock{finished_m_}; + if (--pending_ == 0) { + finished_cv_.notify_all(); + } + } +} + +void SimpleEngine::ThreadWorker() { + OprBlock* opr_block; + while (task_queue_.Pop(&opr_block)) { + assert(opr_block->wait.load() == 0); + auto simple_opr = opr_block->opr; + auto callback = [this, simple_opr]() { + OnComplete(simple_opr); + if (simple_opr->temporary) { + delete simple_opr; + } + }; + if (opr_block->ctx.dev_mask == gpu::kDevMask) { +#if MXNET_USE_CUDA + CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + } + simple_opr->fn(opr_block->rctx, callback); + delete opr_block; + } +} + +} // namespace engine + } // namespace mxnet diff --git a/src/dag_engine/simple_engine.h b/src/dag_engine/simple_engine.h 
new file mode 100644
index 000000000000..21af488f1c1d
--- /dev/null
+++ b/src/dag_engine/simple_engine.h
@@ -0,0 +1,182 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ */
+#ifndef MXNET_DAG_ENGINE_SIMPLE_ENGINE_H_
+#define MXNET_DAG_ENGINE_SIMPLE_ENGINE_H_
+
+// NOTE(review): include targets and template arguments in this file were
+// reconstructed from usage; confirm against the upstream commit.
+#include <dmlc/base.h>
+#include <dmlc/concurrency.h>
+#include <dmlc/logging.h>
+#include <atomic>
+#include <condition_variable>
+#include <cstddef>
+#include <mutex>
+#include <vector>
+#include "dag_engine_impl.h"
+#include "thread_pool.h"
+
+namespace mxnet {
+
+namespace engine {
+
+/*!
+ * \brief Forward declarations.
+ */
+struct SimpleOpr;
+
+/*!
+ * \brief Operation in the queue.
+ *
+ * With DAG_ENGINE_DEBUG defined, construction and destruction log the number
+ * of live instances.
+ */
+struct OprBlock {
+#ifdef DAG_ENGINE_DEBUG
+  static std::atomic<std::size_t> counter;
+  OprBlock() { LOG(INFO) << __func__ << " " << ++counter; }
+  ~OprBlock() { LOG(INFO) << __func__ << " " << --counter; }
+#endif  // DAG_ENGINE_DEBUG
+  std::atomic<std::size_t> wait{0};
+  SimpleOpr* opr{nullptr};
+  Context ctx;
+  RunContext rctx;
+};  // struct OprBlock
+
+/*!
+ * \brief Variable with version information.
+ *
+ * With DAG_ENGINE_DEBUG defined, construction and destruction log the number
+ * of live instances.
+ */
+struct VersionedVarBlock {
+#ifdef DAG_ENGINE_DEBUG
+  static std::atomic<std::size_t> counter;
+  VersionedVarBlock() { LOG(INFO) << __func__ << " " << ++counter; }
+  ~VersionedVarBlock() { LOG(INFO) << __func__ << " " << --counter; }
+#endif  // DAG_ENGINE_DEBUG
+  VersionedVarBlock* next{nullptr};
+  OprBlock* trigger{nullptr};
+  bool write{false};
+};  // struct VersionedVarBlock
+
+/*!
+ * \brief Variable implementation.
+ *
+ * With DAG_ENGINE_DEBUG defined, construction and destruction log the number
+ * of live instances.
+ */
+struct SimpleVar final : public Var {
+#ifdef DAG_ENGINE_DEBUG
+  static std::atomic<std::size_t> counter;
+  SimpleVar() { LOG(INFO) << __func__ << " " << ++counter; }
+  ~SimpleVar() { LOG(INFO) << __func__ << " " << --counter; }
+#endif  // DAG_ENGINE_DEBUG
+  std::mutex m;
+  std::size_t num_pending_reads{0};
+  VersionedVarBlock* head{nullptr};
+  VersionedVarBlock* pending_write{nullptr};
+  /*!
+   * If true, then there is no current or future processing of the chain.
+   */
+  bool ready_to_read{true};
+  /*!
+   * If true, delete after operation completes.
+   */
+  bool to_delete{false};
+
+  static SimpleVar* CastFromBase(Var* ptr);
+};  // struct SimpleVar
+
+/*!
+ * \brief Operator implementation.
+ *
+ * With DAG_ENGINE_DEBUG defined, construction and destruction log the number
+ * of live instances.
+ */
+struct SimpleOpr final : public Opr {
+#ifdef DAG_ENGINE_DEBUG
+  static std::atomic<std::size_t> counter;
+  SimpleOpr() { LOG(INFO) << __func__ << " " << ++counter; }
+  ~SimpleOpr() { LOG(INFO) << __func__ << " " << --counter; }
+#endif  // DAG_ENGINE_DEBUG
+  DAGEngine::AsyncFn fn;
+  std::vector<SimpleVar*> use_vars;
+  std::vector<SimpleVar*> mutate_vars;
+  bool temporary{false};
+
+  static SimpleOpr* CastFromBase(Opr* ptr);
+};  // struct SimpleOpr
+
+/*!
+ * \brief Engine implementation.
+ */
+class SimpleEngine final : public DAGEngine {
+ public:
+  /*!
+   * \brief Constructor and destructor.
+   */
+  SimpleEngine();
+  ~SimpleEngine() noexcept(false);
+  /*!
+   * \brief Overriding methods.
+   */
+  SimpleVar* NewVar() override;
+  SimpleOpr* NewOperator(AsyncFn fn, std::vector<Variable> const& use_vars,
+                         std::vector<Variable> const& mutate_vars) override;
+  void DeleteOperator(OprHandle op) override;
+  void Push(OprHandle op, Context exec_ctx) override;
+  using DAGEngine::Push;
+  void PushAsync(AsyncFn exec_fun, Context exec_ctx,
+                 std::vector<Variable> const& use_vars,
+                 std::vector<Variable> const& mutate_vars) override;
+  void PushDelete(Fn delete_fn, Context exec_ctx, Variable var) override;
+  void WaitForVar(Variable var) override;
+  void WaitForAll() override;
+  /*!
+   * \brief Callback on operation completion.
+   *
+   * On operation completion, this will trigger subsequent operations.
+   */
+  void OnComplete(SimpleOpr* simple_opr);
+  /*!
+   * \brief Worker.
+   *
+   * The method to pass to thread pool to parallelize.
+   */
+  void ThreadWorker();
+
+ private:
+  /*!
+   * \brief Concurrency for thread pool.
+   */
+  static constexpr std::size_t kNumWorkingThreads = 16;
+  /*!
+   * \brief Number of pending operations.
+   */
+  std::atomic<std::size_t> pending_;
+  /*!
+   * \brief Notify waits for single or all variables.
+   */
+  std::mutex finished_m_;
+  std::condition_variable finished_cv_;
+  /*!
+   * \brief Task queue.
+   */
+  dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_;
+  /*!
+   * \brief Thread pool.
+   */
+  ThreadPool<kNumWorkingThreads> thread_pool_;
+  /*!
+   * \brief Disallow copy construction and assignment.
+   */
+  DISALLOW_COPY_AND_ASSIGN(SimpleEngine);
+};  // class SimpleEngine
+
+}  // namespace engine
+
+}  // namespace mxnet
+
+#endif  // MXNET_DAG_ENGINE_SIMPLE_ENGINE_H_
diff --git a/src/dag_engine/thread_pool.h b/src/dag_engine/thread_pool.h
new file mode 100644
index 000000000000..4d5b67cc56a3
--- /dev/null
+++ b/src/dag_engine/thread_pool.h
@@ -0,0 +1,77 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ */
+#ifndef MXNET_DAG_ENGINE_THREAD_POOL_H_
+#define MXNET_DAG_ENGINE_THREAD_POOL_H_
+
+// NOTE(review): include targets and template arguments in this file were
+// reconstructed from usage; confirm against the upstream commit.
+#include <dmlc/base.h>
+#include <array>
+#include <cstddef>
+#include <thread>
+#include <utility>
+#include "mxnet/base.h"
+
+namespace mxnet {
+
+namespace engine {
+
+/*!
+ * \brief Fixed-size pool of `size` worker threads.
+ *
+ * Every worker is started in the constructor and joined in the destructor.
+ */
+template <std::size_t size>
+class ThreadPool {
+ public:
+  /*!
+   * \brief Start `size` threads, each running `func(args...)`.
+   *
+   * Each thread receives its own copy of `func` and `args`, so both must be
+   * copy-constructible.
+   */
+  template <typename Function, typename... Args>
+  explicit ThreadPool(Function&& func, Args&&... args);
+  /*!
+   * \brief Join every worker thread.
+   */
+  ~ThreadPool() noexcept(false);
+
+ private:
+  /*!
+   * \brief Worker threads.
+   */
+  std::array<std::thread, size> worker_threads_;
+  /*!
+   * \brief Disallow default construction.
+   */
+  ThreadPool() = delete;
+  /*!
+   * \brief Disallow copy construction and assignment.
+   */
+  DISALLOW_COPY_AND_ASSIGN(ThreadPool);
+};
+
+template <std::size_t size>
+template <typename Function, typename... Args>
+ThreadPool<size>::ThreadPool(Function&& func, Args&&... args) {
+  // Do not std::forward here: the arguments are reused on every iteration, and
+  // forwarding an rvalue would leave moved-from objects for all but the first
+  // thread. Passing lvalues makes std::thread copy them for each worker.
+  for (auto&& worker : worker_threads_) {
+    worker = std::thread{func, args...};
+  }
+}
+
+template <std::size_t size>
+ThreadPool<size>::~ThreadPool() noexcept(false) {
+  for (auto&& worker : worker_threads_) {
+    worker.join();
+  }
+}
+
+}  // namespace engine
+}  // namespace mxnet
+
+#endif  // MXNET_DAG_ENGINE_THREAD_POOL_H_
diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h
index ff4eb19dc410..f74d73ec8e44 100644
--- a/src/symbol/graph_executor.h
+++ b/src/symbol/graph_executor.h
@@ -90,7 +90,7 @@ class GraphExecutor : public Executor {
   // all the information needed to push the op to engine
   struct OpExecEntry {
     // execution function for
-    DAGEngine::Op exec_fun;
+    DAGEngine::Fn exec_fun;
     // variables to read from
     std::vector<DAGEngine::Variable> use_vars;
     // variables to mutate
diff --git a/tests/test_simple_engine.cc b/tests/test_simple_engine.cc
new file mode 100644
index 000000000000..c65e847bf1ca
--- /dev/null
+++ b/tests/test_simple_engine.cc
@@ -0,0 +1,125 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ */
+// NOTE(review): include targets and template arguments in this file were
+// reconstructed from usage; confirm against the upstream commit.
+#include <dmlc/logging.h>
+#include <chrono>
+#include <cstdio>
+#include <thread>
+#include <vector>
+
+#include "mxnet/dag_engine.h"
+
+/*! \brief Work done by every test operator: print `i`. */
+void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); }
+
+/*!
+ * \brief Smoke test: pushes sleeping operators that read or mutate a variable
+ * and waits on them. There are no assertions; progress is only printed and
+ * logged.
+ */
+int main() {
+  auto&& engine = mxnet::DAGEngine::Get();
+  auto&& var = engine->NewVar();
+  std::vector<mxnet::DAGEngine::OprHandle> oprs;
+
+  // Test #1: ten operators that only read `var`.
+  printf("============= Test #1 ==============\n");
+  for (int i = 0; i < 10; ++i) {
+    oprs.push_back(engine->NewOperator(
+        [i](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) {
+          Foo(ctx, i);
+          std::this_thread::sleep_for(std::chrono::seconds{1});
+          cb();
+        },
+        {var}, {}));
+    engine->Push(oprs.at(i), mxnet::Context{});
+  }
+  engine->WaitForAll();
+  printf("Going to push delete\n");
+  // std::this_thread::sleep_for(std::chrono::seconds{1});
+  for (auto&& i : oprs) {
+    engine->DeleteOperator(i);
+  }
+  engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var);
+  engine->WaitForAll();
+
+  // Test #2: ten operators that all mutate `var`.
+  printf("============= Test #2 ==============\n");
+  var = engine->NewVar();
+  oprs.clear();
+  for (int i = 0; i < 10; ++i) {
+    oprs.push_back(engine->NewOperator(
+        [i](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) {
+          Foo(ctx, i);
+          std::this_thread::sleep_for(std::chrono::milliseconds{500});
+          cb();
+        },
+        {}, {var}));
+    engine->Push(oprs.at(i), mxnet::Context{});
+  }
+  // std::this_thread::sleep_for(std::chrono::seconds{1});
+  engine->WaitForAll();
+  for (auto&& i : oprs) {
+    engine->DeleteOperator(i);
+  }
+  engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var);
+  engine->WaitForAll();
+
+  // Test #3: wait on a variable that no operator has touched.
+  printf("============= Test #3 ==============\n");
+  var = engine->NewVar();
+  oprs.clear();
+  engine->WaitForVar(var);
+  engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var);
+  engine->WaitForAll();
+
+  // Test #4: wait on a variable while one operator is mutating it.
+  printf("============= Test #4 ==============\n");
+  var = engine->NewVar();
+  oprs.clear();
+  oprs.push_back(engine->NewOperator(
+      [](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) {
+        std::this_thread::sleep_for(std::chrono::seconds{2});
+        Foo(ctx, 42);
+        cb();
+      },
+      {}, {var}));
+  engine->Push(oprs.at(0), mxnet::Context{});
+  LOG(INFO) << "Operator pushed, should wait for 2 seconds.";
+  engine->WaitForVar(var);
+  LOG(INFO) << "OK, here I am.";
+  for (auto&& i : oprs) {
+    engine->DeleteOperator(i);
+  }
+  engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var);
+  engine->WaitForAll();
+
+  // Test #5: wait on a variable while one operator is only reading it.
+  printf("============= Test #5 ==============\n");
+  var = engine->NewVar();
+  oprs.clear();
+  oprs.push_back(engine->NewOperator(
+      [](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) {
+        Foo(ctx, 42);
+        std::this_thread::sleep_for(std::chrono::seconds{2});
+        cb();
+      },
+      {var}, {}));
+  engine->Push(oprs.at(0), mxnet::Context{});
+  LOG(INFO) << "Operator pushed, should not wait.";
+  engine->WaitForVar(var);
+  LOG(INFO) << "OK, here I am.";
+  engine->WaitForAll();
+  LOG(INFO) << "That was 2 seconds.";
+  for (auto&& i : oprs) {
+    engine->DeleteOperator(i);
+  }
+  engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var);
+  engine->WaitForAll();
+  var = nullptr;
+  oprs.clear();
+
+  return 0;
+}