diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 8b7359e42991..65fca4ee46cb 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -57,7 +57,7 @@ class Engine { * pointer, that points to an internal data structure of the engine * itself. */ - using Variable = engine::Var*; + using VarHandle = engine::Var*; /*! * \brief Operator of the engine. */ @@ -68,22 +68,25 @@ class Engine { * patterns. * \return The new variable allocated. */ - virtual Variable NewVar() = 0; + virtual VarHandle NewVariable() = 0; /*! * \brief Create a new operator. The returned operator could be saved * externally so that it could be resued for scheduling. * \param fn The execution function. - * \param use_vars The variables that current operation will use but not - * mutate. - * \param mutate_vars Teh variables that current operation will mutate. + * \param const_vars The variables that current operation will use but not + * mutate. + * \param mutable_vars The variables that current operation will mutate. * \return The new operator allocated. */ virtual OprHandle NewOperator(AsyncFn fn, - std::vector const& use_vars, - std::vector const& mutate_vars) = 0; + std::vector const& const_vars, + std::vector const& mutable_vars) = 0; /*! * \brief Delete the given operator. * \param op The operator to delete. + * + * The delete will not happen immediately, but will wait until all the + * operations using this operator are completed. */ virtual void DeleteOperator(OprHandle op) = 0; /*! @@ -96,44 +99,45 @@ class Engine { * \brief Push an synchronous operation to the engine. * \param exec_fun Execution function that executes the operation. * \param exec_ctx Execution context. - * \param use_vars The variables that current operation will use but not - * mutate. - * \param mutate_vars The variables that current operation will mutate. + * \param const_vars The variables that current operation will use but not + * mutate. + * \param mutable_vars The variables that current operation will mutate. */ virtual void Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) = 0; + std::vector const& const_vars, + std::vector const& mutable_vars) = 0; /*! * \brief Push an asynchronous operation to the engine. * \param exec_fun Execution function, this function takes a parameter * on_complete that must be called when the execution * completes. * \param exec_ctx Execution context. - * \param use_vars The variables that current operation will use but not - * mutate. - * \param mutate_vars The variables that current operation will mutate. + * \param const_vars The variables that current operation will use but not + * mutate. + * \param mutable_vars The variables that current operation will mutate. */ virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) = 0; + std::vector const& const_vars, + std::vector const& mutable_vars) = 0; /*! * \brief Schedule the delete of a variable. * * The delete will not happen immediately, but will wait until all the - * operations depending on var is completed. + * operations depending on var are completed. * * \param delete_fun A function that will be called after the variable is * deleted. * \param exec_ctx Execution context. * \param var The variable to be deleted. */ - virtual void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) = 0; + virtual void DeleteVariable(Fn delete_fun, Context exec_ctx, + VarHandle var) = 0; /*! * \brief Wait for variable. * \param var The variable we should wait for, this function returns when all * the operations related to var has been completed. */ - virtual void WaitForVar(Variable var) = 0; + virtual void WaitForVar(VarHandle var) = 0; /*! * \brief Wait until all the activity of engine finishes. */ diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 973b60b049c5..427d2c65b5ed 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -79,7 +79,7 @@ class NArray { Engine::Get()->WaitForVar(ptr_->var); } /*! \return the associated variable of the narray.*/ - inline Engine::Variable var() const { + inline Engine::VarHandle var() const { return ptr_->var; } /*! @@ -207,7 +207,7 @@ class NArray { /*! \brief storage handlefrom storage engine */ Storage::Handle shandle; /*! \brief variable from engine */ - Engine::Variable var; + Engine::VarHandle var; /*! * \brief if this is true, this means the data do not come * from Storage, and do not need to be freed @@ -217,13 +217,13 @@ class NArray { bool delay_alloc; /*! \brief default cosntructor */ Chunk() : static_data(true), delay_alloc(false) { - var = Engine::Get()->NewVar(); + var = Engine::Get()->NewVariable(); } /*! \brief construct from static data */ Chunk(const TBlob &data, int dev_id) : static_data(true), delay_alloc(false) { - var = Engine::Get()->NewVar(); + var = Engine::Get()->NewVariable(); shandle.ctx = Context(data.dev_mask_, dev_id); shandle.dptr = data.dptr_; shandle.size = data.shape_.Size() * sizeof(real_t); @@ -231,7 +231,7 @@ class NArray { /*! \brief construct a new chunk */ Chunk(uint64_t size, Context ctx, bool delay_alloc_) : static_data(false), delay_alloc(true) { - var = Engine::Get()->NewVar(); + var = Engine::Get()->NewVariable(); shandle.size = size * sizeof(real_t); shandle.ctx = ctx; if (!delay_alloc_) this->CheckAndAlloc(); @@ -246,11 +246,11 @@ class NArray { /*! \brief destructor */ ~Chunk() { if (static_data) { - Engine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var); + Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var); } else { CHECK(!delay_alloc) << "deleted before allocation"; Storage::Handle h = this->shandle; - Engine::Get()->PushDelete([h](RunContext s) { + Engine::Get()->DeleteVariable([h](RunContext s) { Storage::Get()->Free(h); }, shandle.ctx, var); } diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index dcd6807fe24b..b9e7ffdb54de 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -7,7 +7,7 @@ namespace mxnet { namespace engine { -NaiveEngine::Variable NaiveEngine::NewVar() { return nullptr; } +NaiveEngine::VarHandle NaiveEngine::NewVariable() { return nullptr; } NaiveEngine::NaiveEngine() { #if MXNET_USE_CUDA @@ -23,8 +23,8 @@ NaiveEngine::~NaiveEngine() { } NaiveEngine::OprHandle NaiveEngine::NewOperator(AsyncFn, - std::vector const&, - std::vector const&) { + std::vector const&, + std::vector const&) { LOG(FATAL) << "Not implemented"; return nullptr; } @@ -34,8 +34,8 @@ void NaiveEngine::DeleteOperator(OprHandle) { LOG(FATAL) << "Not implemented"; } void NaiveEngine::Push(OprHandle, Context) { LOG(FATAL) << "Not implemented"; } void NaiveEngine::Push(Fn exec_fun, Context exec_ctx, - std::vector const&, - std::vector const&) { + std::vector const&, + std::vector const&) { if (exec_ctx.dev_mask == gpu::kDevMask) { #if MXNET_USE_CUDA mshadow::SetDevice(exec_ctx.dev_id); @@ -50,16 +50,17 @@ void NaiveEngine::Push(Fn exec_fun, Context exec_ctx, } } -void NaiveEngine::PushAsync(AsyncFn, Context, std::vector const&, - std::vector const&) { +void NaiveEngine::PushAsync(AsyncFn, Context, std::vector const&, + std::vector const&) { LOG(FATAL) << "Not implemented"; } -void NaiveEngine::PushDelete(Fn delete_fun, Context exec_ctx, Variable var) { +void NaiveEngine::DeleteVariable(Fn delete_fun, Context exec_ctx, + VarHandle var) { this->Push(delete_fun, exec_ctx, {}, {var}); } -void NaiveEngine::WaitForVar(Variable) {} +void NaiveEngine::WaitForVar(VarHandle) {} void NaiveEngine::WaitForAll() {} diff --git a/src/engine/naive_engine.h b/src/engine/naive_engine.h index c172beb6b0b6..991010cbb641 100644 --- a/src/engine/naive_engine.h +++ b/src/engine/naive_engine.h @@ -15,19 +15,19 @@ class NaiveEngine final : public Engine { public: NaiveEngine(); ~NaiveEngine(); - Variable NewVar() override; - OprHandle NewOperator(AsyncFn fn, std::vector const& use_vars, - std::vector const& mutate_vars) override; + VarHandle NewVariable() override; + OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, + std::vector const& mutable_vars) override; void DeleteOperator(OprHandle op) override; void Push(OprHandle op, Context exec_ctx) override; void Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override; + std::vector const& const_vars, + std::vector const& mutable_vars) override; void PushAsync(AsyncFn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override; - void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) override; - void WaitForVar(Variable var) override; + std::vector const& const_vars, + std::vector const& mutable_vars) override; + void DeleteVariable(Fn delete_fun, Context exec_ctx, VarHandle var) override; + void WaitForVar(VarHandle var) override; void WaitForAll() override; private: diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 3df0e4348d7d..3defb5385daf 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -132,38 +132,38 @@ ThreadedEngine::~ThreadedEngine() noexcept(false) { task_queue_.SignalForKill(); } -ThreadedVar* ThreadedEngine::NewVar() { +ThreadedVar* ThreadedEngine::NewVariable() { auto ret = ThreadedVar::New(VersionedVarBlock::New()); return ret; } ThreadedOpr* ThreadedEngine::NewOperator( - ThreadedEngine::AsyncFn fn, std::vector const& use_vars, - std::vector const& mutate_vars) { + ThreadedEngine::AsyncFn fn, std::vector const& const_vars, + std::vector const& mutable_vars) { auto ret = ThreadedOpr::New(); ret->fn = fn; - ret->use_vars.resize(use_vars.size()); - ret->mutate_vars.resize(mutate_vars.size()); - std::transform(use_vars.begin(), use_vars.end(), ret->use_vars.begin(), + ret->const_vars.resize(const_vars.size()); + ret->mutable_vars.resize(mutable_vars.size()); + std::transform(const_vars.begin(), const_vars.end(), ret->const_vars.begin(), ThreadedVar::CastFromBase); - std::transform(mutate_vars.begin(), mutate_vars.end(), - ret->mutate_vars.begin(), ThreadedVar::CastFromBase); + std::transform(mutable_vars.begin(), mutable_vars.end(), + ret->mutable_vars.begin(), ThreadedVar::CastFromBase); #if ENGINE_DEBUG // Check for duplicates. - auto use = use_vars; - auto mutate = mutate_vars; + auto use = const_vars; + auto mutate = mutable_vars; auto use_size = use.size(); auto mutate_size = mutate.size(); std::sort(use.begin(), use.end()); std::sort(mutate.begin(), mutate.end()); for (std::size_t i = 0; i < use_size; ++i) { if (i != 0 && use.at(i) == use.at(i - 1)) { - LOG(FATAL) << "duplicate items found in `use_vars`"; + LOG(FATAL) << "duplicate items found in `const_vars`"; } } for (std::size_t i = 0; i < mutate_size; ++i) { if (i != 0 && mutate.at(i) == mutate.at(i - 1)) { - LOG(FATAL) << "duplicate items found in `mutate_vars`"; + LOG(FATAL) << "duplicate items found in `mutable_vars`"; } } std::size_t j = 0; @@ -176,7 +176,7 @@ ThreadedOpr* ThreadedEngine::NewOperator( } if (mutate.at(j) == use.at(i)) { LOG(FATAL) - << "duplicate items found between `use_vars` and `mutate_vars`"; + << "duplicate items found between `const_vars` and `mutable_vars`"; } } #endif // ENGINE_DEBUG @@ -185,43 +185,43 @@ ThreadedOpr* ThreadedEngine::NewOperator( void ThreadedEngine::DeleteOperator(OprHandle op) { auto&& threaded_opr = ThreadedOpr::CastFromBase(op); - std::vector deps{}; - deps.reserve(threaded_opr->use_vars.size() + - threaded_opr->mutate_vars.size()); - deps.insert(deps.end(), threaded_opr->use_vars.begin(), - threaded_opr->use_vars.end()); - deps.insert(deps.end(), threaded_opr->mutate_vars.begin(), - threaded_opr->mutate_vars.end()); + std::vector deps{}; + deps.reserve(threaded_opr->const_vars.size() + + threaded_opr->mutable_vars.size()); + deps.insert(deps.end(), threaded_opr->const_vars.begin(), + threaded_opr->const_vars.end()); + deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), + threaded_opr->mutable_vars.end()); auto&& func = [threaded_opr](RunContext) { ThreadedOpr::Delete(threaded_opr); }; Push(func, Context{}, {}, deps); } void ThreadedEngine::Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) { + std::vector const& const_vars, + std::vector const& mutable_vars) { auto f = [exec_fun](RunContext ctx, Callback on_complete) { exec_fun(ctx); on_complete(); }; - PushAsync(f, exec_ctx, use_vars, mutate_vars); + PushAsync(f, exec_ctx, const_vars, mutable_vars); } void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { auto&& threaded_opr = ThreadedOpr::CastFromBase(op); auto&& opr_block = OprBlock::New(); opr_block->opr = threaded_opr; - opr_block->wait.store(threaded_opr->use_vars.size() + - threaded_opr->mutate_vars.size() + 1); + opr_block->wait.store(threaded_opr->const_vars.size() + + threaded_opr->mutable_vars.size() + 1); opr_block->ctx = exec_ctx; opr_block->rctx = RunContext{nullptr}; ++pending_; // Add read dependencies. - for (auto&& i : threaded_opr->use_vars) { + for (auto&& i : threaded_opr->const_vars) { i->AppendReadDependency(opr_block); } // Add write dependencies. - for (auto&& i : threaded_opr->mutate_vars) { + for (auto&& i : threaded_opr->mutable_vars) { i->AppendWriteDependency(opr_block); } if (--opr_block->wait == 0) { @@ -230,14 +230,15 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { } void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) { - auto&& opr = NewOperator(fn, use_vars, mutate_vars); + std::vector const& const_vars, + std::vector const& mutable_vars) { + auto&& opr = NewOperator(fn, const_vars, mutable_vars); opr->temporary = true; Push(opr, exec_ctx); } -void ThreadedEngine::PushDelete(Fn delete_fn, Context exec_ctx, Variable var) { +void ThreadedEngine::DeleteVariable(Fn delete_fn, Context exec_ctx, + VarHandle var) { auto&& threaded_var = ThreadedVar::CastFromBase(var); auto&& func = [delete_fn, threaded_var](RunContext ctx) { /*! @@ -250,7 +251,7 @@ void ThreadedEngine::PushDelete(Fn delete_fn, Context exec_ctx, Variable var) { Push(func, exec_ctx, {}, {var}); } -void ThreadedEngine::WaitForVar(Variable var) { +void ThreadedEngine::WaitForVar(VarHandle var) { std::unique_lock lock{finished_m_}; std::atomic done{false}; auto&& callback = [this, &done](RunContext) { @@ -271,13 +272,13 @@ void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { /*! * Mark complete for read variables. */ - for (auto&& i : threaded_opr->use_vars) { + for (auto&& i : threaded_opr->const_vars) { i->CompleteReadDependency([this](OprBlock* opr) { task_queue_.Push(opr); }); } /*! * Mark complete for write variables. */ - for (auto&& i : threaded_opr->mutate_vars) { + for (auto&& i : threaded_opr->mutable_vars) { bool to_delete = i->CompleteWriteDependency( [this](OprBlock* opr) { task_queue_.Push(opr); }); if (to_delete) { diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index d13cc64026a1..446249cebb99 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -103,8 +103,8 @@ struct ThreadedOpr final : public Opr, ~ThreadedOpr() { LOG(INFO) << __func__ << " " << --counter; } #endif // ENGINE_DEBUG Engine::AsyncFn fn; - std::vector use_vars; - std::vector mutate_vars; + std::vector const_vars; + std::vector mutable_vars; bool temporary{false}; static ThreadedOpr* CastFromBase(Opr* ptr); @@ -123,19 +123,19 @@ class ThreadedEngine final : public Engine { /*! * \brief Overriding methods. */ - ThreadedVar* NewVar() override; - ThreadedOpr* NewOperator(AsyncFn fn, std::vector const& use_vars, - std::vector const& mutate_vars) override; + ThreadedVar* NewVariable() override; + ThreadedOpr* NewOperator(AsyncFn fn, std::vector const& const_vars, + std::vector const& mutable_vars) override; void DeleteOperator(OprHandle op) override; void Push(OprHandle op, Context exec_ctx) override; void Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override; + std::vector const& const_vars, + std::vector const& mutable_vars) override; void PushAsync(AsyncFn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override; - void PushDelete(Fn delete_fn, Context exec_ctx, Variable var) override; - void WaitForVar(Variable var) override; + std::vector const& const_vars, + std::vector const& mutable_vars) override; + void DeleteVariable(Fn delete_fn, Context exec_ctx, VarHandle var) override; + void WaitForVar(VarHandle var) override; void WaitForAll() override; /*! * \brief Callback on operation completion. diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 2dccdc95cd51..af2160415e8d 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -93,9 +93,9 @@ class GraphExecutor : public Executor { // execution function for Engine::Fn exec_fun; // variables to read from - std::vector use_vars; + std::vector use_vars; // variables to mutate - std::vector mutate_vars; + std::vector mutate_vars; // constructor OpExecEntry() : exec_fun(nullptr) {} }; diff --git a/tests/test_threaded_engine.cc b/tests/test_threaded_engine.cc index 35e4e563f4cb..9e9c197770ac 100644 --- a/tests/test_threaded_engine.cc +++ b/tests/test_threaded_engine.cc @@ -14,7 +14,7 @@ void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } int main() { auto&& engine = mxnet::Engine::Get(); - auto&& var = engine->NewVar(); + auto&& var = engine->NewVariable(); std::vector oprs; // Test #1 @@ -39,7 +39,7 @@ int main() { engine->WaitForAll(); printf("============= Test #2 ==============\n"); - var = engine->NewVar(); + var = engine->NewVariable(); oprs.clear(); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( @@ -59,14 +59,14 @@ int main() { engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); printf("============= Test #3 ==============\n"); - var = engine->NewVar(); + var = engine->NewVariable(); oprs.clear(); engine->WaitForVar(var); engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); engine->WaitForAll(); printf("============= Test #4 ==============\n"); - var = engine->NewVar(); + var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( [](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { @@ -86,7 +86,7 @@ int main() { engine->WaitForAll(); printf("============= Test #5 ==============\n"); - var = engine->NewVar(); + var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( [](mxnet::RunContext ctx, mxnet::Engine::Callback cb) {