Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[engine-refactor] renaming [fixes #40]
Browse files Browse the repository at this point in the history
  • Loading branch information
hotpxl committed Sep 7, 2015
1 parent ad97494 commit 261c757
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 97 deletions.
44 changes: 24 additions & 20 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Engine {
* pointer, that points to an internal data structure of the engine
* itself.
*/
using Variable = engine::Var*;
using VarHandle = engine::Var*;
/*!
* \brief Operator of the engine.
*/
Expand All @@ -68,22 +68,25 @@ class Engine {
* patterns.
* \return The new variable allocated.
*/
virtual Variable NewVar() = 0;
virtual VarHandle NewVariable() = 0;
/*!
* \brief Create a new operator. The returned operator could be saved
* externally so that it could be resued for scheduling.
* \param fn The execution function.
* \param use_vars The variables that current operation will use but not
* mutate.
* \param mutate_vars Teh variables that current operation will mutate.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
* \return The new operator allocated.
*/
virtual OprHandle NewOperator(AsyncFn fn,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) = 0;
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) = 0;
/*!
* \brief Delete the given operator.
* \param op The operator to delete.
*
* The delete will not happen immediately, but will wait until all the
* operations using this operator are completed.
*/
virtual void DeleteOperator(OprHandle op) = 0;
/*!
Expand All @@ -96,44 +99,45 @@ class Engine {
* \brief Push an synchronous operation to the engine.
* \param exec_fun Execution function that executes the operation.
* \param exec_ctx Execution context.
* \param use_vars The variables that current operation will use but not
* mutate.
* \param mutate_vars The variables that current operation will mutate.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
*/
virtual void Push(Fn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) = 0;
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) = 0;
/*!
* \brief Push an asynchronous operation to the engine.
* \param exec_fun Execution function, this function takes a parameter
* on_complete that must be called when the execution
* completes.
* \param exec_ctx Execution context.
* \param use_vars The variables that current operation will use but not
* mutate.
* \param mutate_vars The variables that current operation will mutate.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
*/
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) = 0;
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) = 0;
/*!
* \brief Schedule the delete of a variable.
*
* The delete will not happen immediately, but will wait until all the
* operations depending on var is completed.
* operations depending on var are completed.
*
* \param delete_fun A function that will be called after the variable is
* deleted.
* \param exec_ctx Execution context.
* \param var The variable to be deleted.
*/
virtual void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) = 0;
virtual void DeleteVariable(Fn delete_fun, Context exec_ctx,
VarHandle var) = 0;
/*!
* \brief Wait for variable.
* \param var The variable we should wait for, this function returns when all
* the operations related to var has been completed.
*/
virtual void WaitForVar(Variable var) = 0;
virtual void WaitForVar(VarHandle var) = 0;
/*!
* \brief Wait until all the activity of engine finishes.
*/
Expand Down
14 changes: 7 additions & 7 deletions include/mxnet/narray.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class NArray {
Engine::Get()->WaitForVar(ptr_->var);
}
/*! \return the associated variable of the narray.*/
inline Engine::Variable var() const {
inline Engine::VarHandle var() const {
return ptr_->var;
}
/*!
Expand Down Expand Up @@ -207,7 +207,7 @@ class NArray {
/*! \brief storage handlefrom storage engine */
Storage::Handle shandle;
/*! \brief variable from engine */
Engine::Variable var;
Engine::VarHandle var;
/*!
* \brief if this is true, this means the data do not come
* from Storage, and do not need to be freed
Expand All @@ -217,21 +217,21 @@ class NArray {
bool delay_alloc;
/*! \brief default cosntructor */
Chunk() : static_data(true), delay_alloc(false) {
var = Engine::Get()->NewVar();
var = Engine::Get()->NewVariable();
}
/*! \brief construct from static data */
Chunk(const TBlob &data, int dev_id)
: static_data(true),
delay_alloc(false) {
var = Engine::Get()->NewVar();
var = Engine::Get()->NewVariable();
shandle.ctx = Context(data.dev_mask_, dev_id);
shandle.dptr = data.dptr_;
shandle.size = data.shape_.Size() * sizeof(real_t);
}
/*! \brief construct a new chunk */
Chunk(uint64_t size, Context ctx, bool delay_alloc_)
: static_data(false), delay_alloc(true) {
var = Engine::Get()->NewVar();
var = Engine::Get()->NewVariable();
shandle.size = size * sizeof(real_t);
shandle.ctx = ctx;
if (!delay_alloc_) this->CheckAndAlloc();
Expand All @@ -246,11 +246,11 @@ class NArray {
/*! \brief destructor */
~Chunk() {
if (static_data) {
Engine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var);
Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var);
} else {
CHECK(!delay_alloc) << "deleted before allocation";
Storage::Handle h = this->shandle;
Engine::Get()->PushDelete([h](RunContext s) {
Engine::Get()->DeleteVariable([h](RunContext s) {
Storage::Get()->Free(h);
}, shandle.ctx, var);
}
Expand Down
19 changes: 10 additions & 9 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace mxnet {
namespace engine {

NaiveEngine::Variable NaiveEngine::NewVar() { return nullptr; }
NaiveEngine::VarHandle NaiveEngine::NewVariable() { return nullptr; }

NaiveEngine::NaiveEngine() {
#if MXNET_USE_CUDA
Expand All @@ -23,8 +23,8 @@ NaiveEngine::~NaiveEngine() {
}

NaiveEngine::OprHandle NaiveEngine::NewOperator(AsyncFn,
std::vector<Variable> const&,
std::vector<Variable> const&) {
std::vector<VarHandle> const&,
std::vector<VarHandle> const&) {
LOG(FATAL) << "Not implemented";
return nullptr;
}
Expand All @@ -34,8 +34,8 @@ void NaiveEngine::DeleteOperator(OprHandle) { LOG(FATAL) << "Not implemented"; }
void NaiveEngine::Push(OprHandle, Context) { LOG(FATAL) << "Not implemented"; }

void NaiveEngine::Push(Fn exec_fun, Context exec_ctx,
std::vector<Variable> const&,
std::vector<Variable> const&) {
std::vector<VarHandle> const&,
std::vector<VarHandle> const&) {
if (exec_ctx.dev_mask == gpu::kDevMask) {
#if MXNET_USE_CUDA
mshadow::SetDevice<gpu>(exec_ctx.dev_id);
Expand All @@ -50,16 +50,17 @@ void NaiveEngine::Push(Fn exec_fun, Context exec_ctx,
}
}

void NaiveEngine::PushAsync(AsyncFn, Context, std::vector<Variable> const&,
std::vector<Variable> const&) {
void NaiveEngine::PushAsync(AsyncFn, Context, std::vector<VarHandle> const&,
std::vector<VarHandle> const&) {
LOG(FATAL) << "Not implemented";
}

void NaiveEngine::PushDelete(Fn delete_fun, Context exec_ctx, Variable var) {
void NaiveEngine::DeleteVariable(Fn delete_fun, Context exec_ctx,
VarHandle var) {
this->Push(delete_fun, exec_ctx, {}, {var});
}

void NaiveEngine::WaitForVar(Variable) {}
void NaiveEngine::WaitForVar(VarHandle) {}

void NaiveEngine::WaitForAll() {}

Expand Down
18 changes: 9 additions & 9 deletions src/engine/naive_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ class NaiveEngine final : public Engine {
public:
NaiveEngine();
~NaiveEngine();
Variable NewVar() override;
OprHandle NewOperator(AsyncFn fn, std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) override;
VarHandle NewVariable() override;
OprHandle NewOperator(AsyncFn fn, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) override;
void DeleteOperator(OprHandle op) override;
void Push(OprHandle op, Context exec_ctx) override;
void Push(Fn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) override;
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) override;
void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) override;
void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) override;
void WaitForVar(Variable var) override;
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) override;
void DeleteVariable(Fn delete_fun, Context exec_ctx, VarHandle var) override;
void WaitForVar(VarHandle var) override;
void WaitForAll() override;

private:
Expand Down
Loading

0 comments on commit 261c757

Please sign in to comment.