
Merge pull request #100 from tqchen/master
remove finalize, ignore cuda driver shutdown, add release to pinned …
tqchen committed Sep 19, 2015
2 parents 8d40480 + 84df51a commit dafe1ee
Showing 23 changed files with 82 additions and 144 deletions.
4 changes: 3 additions & 1 deletion Makefile
@@ -70,7 +70,7 @@ ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif

.PHONY: clean all test lint doc
.PHONY: clean all test lint doc clean_all

all: lib/libmxnet.a lib/libmxnet.so $(BIN)

@@ -116,6 +116,8 @@ doxygen:

clean:
$(RM) -r build lib/lib* *~ */*~ */*/*~ */*/*/*~

clean_all: clean
cd $(DMLC_CORE); make clean; cd -

-include build/*.d
17 changes: 11 additions & 6 deletions example/cifar10/cifar10.py
@@ -1,11 +1,16 @@
# pylint: skip-file
import numpy as np
import mxnet as mx
import copy
import sys
import sys, os
# code to directly use library
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, "../../python/")
sys.path.append("../../tests/python/common")
# import library
import mxnet as mx
import get_data
import time
import numpy as np
import copy


"""
CXXNET Result:
@@ -163,7 +168,7 @@ def RandomInit(narray):

in_data = mx.nd.empty(data_shape, mx.gpu())
executor = loss.simple_bind(mx.gpu(), data = in_data)
print executor.debug_str()


out_narray = executor.outputs[0]
pred = mx.nd.zeros(out_narray.shape, mx.cpu())
@@ -176,7 +181,7 @@ def RandomInit(narray):
block = list(zip(grad_narrays, arg_narrays, momentum_narrays))

np.random.seed(0)
# set random weight


for name, narray in inputs.items():
if "weight" in name:
5 changes: 0 additions & 5 deletions include/mxnet/base.h
@@ -64,11 +64,6 @@ typedef mshadow::TShape TShape;
/*! \brief storage container type */
typedef mshadow::TBlob TBlob;

/*!
* \brief Finalize and shutdown all related modules of mxnet.
* Call this function at end of program to ensure correct shutdown.
*/
void Finalize();

/*! \brief Context information about the execution enviroment */
struct Context {
5 changes: 0 additions & 5 deletions include/mxnet/c_api.h
@@ -62,11 +62,6 @@ MXNET_DLL const char *MXGetLastError();
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXRandomSeed(int seed);
/*!
* \brief Finalize and shutdown all related modules of mxnet.
* Call this function at end of program to ensure correct shutdown.
*/
MXNET_DLL int MXFinalize();
//-------------------------------------
// Part 1: NDArray creation and deletion
//-------------------------------------
8 changes: 0 additions & 8 deletions include/mxnet/engine.h
@@ -199,14 +199,6 @@ class Engine {
ret.param_ = param;
return ret;
}
// friend function
friend void ::mxnet::Finalize();
/*!
* \brief Idempotent Finalize function.
* This function will signal engine to release all resources.
* It is safe to call this function multiple times.
*/
virtual void Finalize() = 0;
}; // class Engine
#endif // DMLC_USE_CXX11
} // namespace mxnet
10 changes: 0 additions & 10 deletions include/mxnet/resource.h
@@ -115,16 +115,6 @@ class ResourceManager {
* \return Resource manager singleton.
*/
static ResourceManager *Get();

protected:
// friend function
friend void ::mxnet::Finalize();
/*!
* \brief Idempotent Finalize function.
* This function will signal resource manager to release all resources.
* It is safe to call this function multiple times.
*/
virtual void Finalize() = 0;
};
} // namespace mxnet
#endif // MXNET_RESOURCE_H_
10 changes: 0 additions & 10 deletions include/mxnet/storage.h
@@ -63,16 +63,6 @@ class Storage {
*/
static std::shared_ptr<Storage> _GetSharedRef();

protected:
// friend function
friend void ::mxnet::Finalize();
/*!
* \brief Idempotent Finalize function.
* This function will signal engine to release all resources.
* It is safe to call this function multiple times.
*/
void Finalize();

private:
/*!
* \brief Hidden constructors.
2 changes: 1 addition & 1 deletion mshadow
13 changes: 0 additions & 13 deletions python/mxnet/__init__.py
@@ -18,18 +18,5 @@
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
from . import random
import atexit

__version__ = "0.1.0"

def finalize():
"""Stop all the components in mxnet.
There is no need to call this function.
This function will be automatically called at module exit.
"""
# pylint: disable=protected-access
base.check_call(base._LIB.MXFinalize())
kv._cleanup()

atexit.register(finalize)
2 changes: 2 additions & 0 deletions python/mxnet/kvstore.py
@@ -6,6 +6,7 @@
from .ndarray import NDArray
from .base import _LIB
from .base import check_call, c_array, NDArrayHandle
import atexit

__all__ = ['start', 'init', 'push', 'pull', 'set_updater']

@@ -218,3 +219,4 @@ def _cleanup():
global _updater_func
_updater_func = None

atexit.register(_cleanup)
6 changes: 0 additions & 6 deletions src/c_api.cc
@@ -187,12 +187,6 @@ int MXRandomSeed(int seed) {
API_END();
}

int MXFinalize() {
API_BEGIN();
mxnet::Finalize();
API_END();
}

int MXNDArrayCreateNone(NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray();
14 changes: 9 additions & 5 deletions src/engine/naive_engine.cc
@@ -15,14 +15,18 @@ class NaiveEngine final : public Engine {
}
// virtual destructor
virtual ~NaiveEngine() {
Finalize();
}

void Finalize() override {
#if MXNET_USE_CUDA
for (size_t i = 0; i < streams_.size(); ++i) {
if (streams_[i] != nullptr) {
mshadow::DeleteStream(streams_[i]);
// Catch exception for CUDA driver shutdown
try {
mshadow::DeleteStream(streams_[i]);
} catch (const dmlc::Error &e) {
std::string what = e.what();
if (what.find("driver shutting down") == std::string::npos) {
LOG(ERROR) << "Ignore Error " << what << " during worker finalization";
}
}
streams_[i] = nullptr;
}
}
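The try/catch guard introduced above around mshadow::DeleteStream is repeated below in stream_manager.h and threaded_engine_perdevice.cc. As a minimal sketch of that shared pattern (not part of this commit; the helper name DeleteStreamIgnoreShutdown is hypothetical), it could be factored into a single function:

    #include <string>
    #include <dmlc/logging.h>
    #include <mshadow/tensor.h>

    // Delete an mshadow stream, but ignore the dmlc::Error thrown when the
    // CUDA driver is already shutting down at process exit; any other error
    // is logged, mirroring the behavior added in this commit.
    template <typename Device>
    inline void DeleteStreamIgnoreShutdown(mshadow::Stream<Device> *stream) {
      try {
        mshadow::DeleteStream<Device>(stream);
      } catch (const dmlc::Error &e) {
        std::string what = e.what();
        if (what.find("driver shutting down") == std::string::npos) {
          LOG(ERROR) << "Ignore Error " << what << " during worker finalization";
        }
      }
    }
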
11 changes: 10 additions & 1 deletion src/engine/stream_manager.h
@@ -8,6 +8,7 @@
#include <mxnet/base.h>
#include <cstddef>
#include <array>
#include <string>
#include <mutex>
#include "../common/cuda_utils.h"

@@ -118,7 +119,15 @@ void StreamManager<kNumGpus, kStreams>::Finalize() {
for (std::size_t i = 0; i < kNumGpus; ++i) {
if (gpu_cnt_.at(i) != -1) {
for (auto&& j : gpu_streams_.at(i)) {
mshadow::DeleteStream<gpu>(j);
// Catch exception for CUDA driver shutdown
try {
mshadow::DeleteStream<gpu>(j);
} catch (const dmlc::Error &e) {
std::string what = e.what();
if (what.find("driver shutting down") == std::string::npos) {
LOG(ERROR) << "Ignore Error " << what << " during worker finalization";
}
}
}
gpu_cnt_.at(i) = -1;
}
6 changes: 0 additions & 6 deletions src/engine/threaded_engine.cc
@@ -291,12 +291,6 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() { return pending_.load() == 0; });
}

void ThreadedEngine::Finalize() {
// unlock all threads
pending_.store(0);
finished_cv_.notify_all();
}

inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
1 change: 0 additions & 1 deletion src/engine/threaded_engine.h
@@ -249,7 +249,6 @@ class ThreadedEngine : public Engine {
threaded_opr->fn(run_ctx, callback);
OprBlock::Delete(opr_block);
}
void Finalize() override;

private:
/*!
22 changes: 12 additions & 10 deletions src/engine/threaded_engine_perdevice.cc
@@ -7,7 +7,6 @@
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <dmlc/concurrency.h>
#include <array>
#include "./threaded_engine.h"
#include "./thread_pool.h"
#include "../common/lazy_alloc_array.h"
@@ -38,7 +37,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
// GPU tasks will be created lazily
}
~ThreadedEnginePerDevice() noexcept(false) {
Finalize();
gpu_normal_workers_.Clear();
gpu_copy_workers_.Clear();
cpu_worker_.reset(nullptr);
}

protected:
@@ -64,13 +65,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
}
}
}
// finalize the internal resources
void Finalize() override {
gpu_normal_workers_.Clear();
gpu_copy_workers_.Clear();
cpu_worker_.reset(nullptr);
ThreadedEngine::Finalize();
}

private:
// working unit for each of the task.
@@ -145,7 +139,15 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
while (task_queue->Pop(&opr_block)) {
this->ExecuteOprBlock(run_ctx, opr_block);
}
mshadow::DeleteStream<gpu>(stream);
// Catch exception for CUDA driver shutdown
try {
mshadow::DeleteStream<gpu>(stream);
} catch (const dmlc::Error &e) {
std::string what = e.what();
if (what.find("driver shutting down") == std::string::npos) {
LOG(ERROR) << "Ignore Error " << what << " during worker finalization";
}
}
#endif
}
/*!
18 changes: 6 additions & 12 deletions src/engine/threaded_engine_pooled.cc
@@ -28,7 +28,12 @@ class ThreadedEnginePooled : public ThreadedEngine {
io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {}

~ThreadedEnginePooled() noexcept(false) {
Finalize();
// wait until all the tasks are completed.
// TODO(hotpxl) think if this is the correct thing to do
this->WaitForAll();
streams_.Finalize();
task_queue_.SignalForKill();
io_task_queue_.SignalForKill();
}

protected:
@@ -39,17 +44,6 @@ class ThreadedEnginePooled : public ThreadedEngine {
DoPushToQueue(opr_block);
}
}
// finalize the internal resources
void Finalize() override {
// wait until all the tasks are completed.
// TODO(hotpxl) think if this is the correct thing to do
this->WaitForAll();
streams_.Finalize();
task_queue_.SignalForKill();
io_task_queue_.SignalForKill();
ThreadedEngine::Finalize();
}


private:
/*! \brief Concurrency for thread pool */
21 changes: 0 additions & 21 deletions src/global.cc

This file was deleted.

29 changes: 12 additions & 17 deletions src/resource.cc
@@ -29,7 +29,18 @@ class ResourceManagerImpl : public ResourceManager {
Context(cpu::kDevMask, 0), cpu_temp_space_copy_));
}
~ResourceManagerImpl() {
Finalize();
// need explicit delete, before engine get killed
cpu_rand_.reset(nullptr);
cpu_space_.reset(nullptr);
#if MXNET_USE_CUDA
gpu_rand_.Clear();
gpu_space_.Clear();
#endif
if (engine_ref_ != nullptr) {
engine_ref_->WaitForAll();
// release the reference to engine.
engine_ref_ = nullptr;
}
}

// request resources
@@ -74,22 +85,6 @@ class ResourceManagerImpl : public ResourceManager {
#endif
}

protected:
void Finalize() override {
// need explicit delete, before engine get killed
cpu_rand_.reset(nullptr);
cpu_space_.reset(nullptr);
#if MXNET_USE_CUDA
gpu_rand_.Clear();
gpu_space_.Clear();
#endif
if (engine_ref_ != nullptr) {
engine_ref_->WaitForAll();
// release the reference to engine.
engine_ref_ = nullptr;
}
}

private:
/*! \brief Maximum number of GPUs */
static constexpr std::size_t kMaxNumGPUs = 16;
