Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #83 from tqchen/master
Browse files Browse the repository at this point in the history
Add resource manager and random
  • Loading branch information
antinucleon committed Sep 16, 2015
2 parents 81318a1 + 46e776b commit bc8ffb2
Show file tree
Hide file tree
Showing 23 changed files with 655 additions and 109 deletions.
10 changes: 10 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ typedef void *DataIterHandle;
* \return error info
*/
MXNET_DLL const char *MXGetLastError();

//-------------------------------------
// Part 0: Global State setups
//-------------------------------------
/*!
* \brief Seed the global random number generators in mxnet.
* \param seed the random number seed.
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXRandomSeed(int seed);
//-------------------------------------
// Part 1: NDArray creation and deletion
//-------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <dmlc/base.h>
#if DMLC_USE_CXX11
#include <memory>
#include <functional>
#endif
#include <vector>
Expand Down Expand Up @@ -154,6 +155,15 @@ class Engine {
* \return Engine singleton.
*/
static Engine* Get();
/*!
* \brief Get shared pointer reference to engine singleton.
* Most user should not call this function.
* This function is called by another singleton X who requires
* engine to be destructed after X.
*
* \return A shared pointer to Engine singleton.
*/
static std::shared_ptr<Engine> _GetSharedRef();
/*!
* \brief Push an synchronous operation to the engine.
* \param exec_fn Execution function that executes the operation.
Expand Down
58 changes: 37 additions & 21 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ class NDArray {
ret.shape_ = shape;
return ret;
}
/*!
* \brief Allocate the space if it is delayed allocated.
* This is an internal function used by system that normal user should not use
*/
inline void CheckAndAlloc() const {
ptr_->CheckAndAlloc();
}

private:
/*! \brief the real data chunk that backs NDArray */
Expand Down Expand Up @@ -299,16 +306,6 @@ class NDArray {
TShape shape_;
/*! \brief offset in chunk */
size_t offset_;

// add friend to helper functions
friend void CopyFromTo(const NDArray &from, NDArray *to);
template<typename OP>
friend void BinaryOp(const NDArray &lhs, const NDArray &rhs, NDArray *out);
template<typename OP>
friend void UnaryOp(const NDArray &lhs, const NDArray &rhs, NDArray *out);
template<typename OP, bool reverse>
friend void ScalarOp(const NDArray &lhs, const real_t &rhs, NDArray *out);
friend void SetValueOp(const real_t &rhs, NDArray *out);
};

/*!
Expand Down Expand Up @@ -380,6 +377,27 @@ NDArray operator/(const NDArray &lhs, const NDArray &rhs);
*/
NDArray operator/(const NDArray &lhs, const real_t &rhs);

/*!
* \brief Seed the random number generator.
* \param seed the seed to set to global random number generators.
*/
void RandomSeed(uint32_t seed);
/*!
* \brief Sample uniform distribution for each elements of out.
* \param begin lower bound of distribution.
* \param end upper bound of distribution.
* \param out output NDArray.
*/
void SampleUniform(real_t begin, real_t end, NDArray *out);

/*!
* \brief Sample gaussian distribution for each elements of out.
* \param mu mean of gaussian distribution.
* \param sigma standard deviation of gaussian distribution.
* \param out output NDArray.
*/
void SampleGaussian(real_t mu, real_t sigma, NDArray *out);

//--------------------------------------------------------------
// The following part are API Registration of NDArray functions.
//--------------------------------------------------------------
Expand Down Expand Up @@ -430,14 +448,12 @@ struct NDArrayFunctionReg
* \return ref to the registered entry, used to set properties
*/
inline NDArrayFunctionReg &set_function(void fsetvalue(const real_t &rhs,
NDArray *out)) {
body = [fsetvalue] (NDArray **used_vars,
real_t *s, NDArray **mutate_vars) {
NDArray *out)) {
body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) {
fsetvalue(s[0], mutate_vars[0]);
};
num_mutate_vars = 1; num_scalars = 1;
// type_mask = kNDArrayArgBeforeScalar;
this->add_argument("rhs", "real_t", "Right operand to the function.");
this->add_argument("src", "real_t", "Source input to the function.");
return *this;
}
/*!
Expand All @@ -447,8 +463,8 @@ struct NDArrayFunctionReg
* \return ref to the registered entry, used to set properties
*/
inline NDArrayFunctionReg &set_function(void fbinary(const NDArray &lhs,
const NDArray &rhs,
NDArray *out)) {
const NDArray &rhs,
NDArray *out)) {
body = [fbinary] (NDArray **used_vars,
real_t *s, NDArray **mutate_vars) {
fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]);
Expand All @@ -466,10 +482,10 @@ struct NDArrayFunctionReg
* \return ref to the registered entry, used to set properties
*/
inline NDArrayFunctionReg &set_function(void fscalar(const NDArray &lhs,
const real_t &rhs,
NDArray *out)) {
const real_t &rhs,
NDArray *out)) {
body = [fscalar] (NDArray **used_vars,
real_t *s, NDArray **mutate_vars) {
real_t *s, NDArray **mutate_vars) {
fscalar(*used_vars[0], s[0], mutate_vars[0]);
};
num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
Expand All @@ -485,7 +501,7 @@ struct NDArrayFunctionReg
* \return ref to the registered entry, used to set properties
*/
inline NDArrayFunctionReg &set_function(void funary(const NDArray &src,
NDArray *out)) {
NDArray *out)) {
body = [funary] (NDArray **used_vars,
real_t *s, NDArray **mutate_vars) {
funary(*used_vars[0], mutate_vars[0]);
Expand Down
35 changes: 33 additions & 2 deletions include/mxnet/resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,18 @@ struct Resource {
void *ptr_;
/*!
* \brief Get random number generator.
* \param The stream to use in the random number generator.
* \return the mshadow random number generator requested.
* \tparam xpu the device type of random number generator.
*/
template<typename xpu>
inline mshadow::Random<xpu>* get_random() const {
inline mshadow::Random<xpu>* get_random(
mshadow::Stream<xpu> *stream) const {
CHECK_EQ(req.type, ResourceRequest::kRandom);
return static_cast<mshadow::Random<xpu>*>(ptr_);
mshadow::Random<xpu> *ret =
static_cast<mshadow::Random<xpu>*>(ptr_);
ret->set_stream(stream);
return ret;
}
/*!
* \brief Get space requested as mshadow Tensor.
Expand All @@ -81,5 +86,31 @@ struct Resource {
static_cast<real_t*>(ptr_), shape, shape[ndim - 1], stream);
}
};

/*! \brief Global resource manager */
class ResourceManager {
public:
/*!
* \brief Get resource of requested type.
* \param ctx the context of the request.
* \param req the resource request.
* \return the requested resource.
* \note The returned resource's ownership is
* still hold by the manager singleton.
*
*/
virtual Resource Request(Context ctx, const ResourceRequest &req) = 0;
/*!
* \brief Seed all the allocated random numbers.
* \param seed the seed to the random number generators on all devices.
*/
virtual void SeedRandom(uint32_t seed) = 0;
/*! \brief virtual destructor */
virtual ~ResourceManager() {}
/*!
* \return Resource manager singleton.
*/
static ResourceManager *Get();
};
} // namespace mxnet
#endif // MXNET_RESOURCE_H_
9 changes: 9 additions & 0 deletions include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ class Storage {
* \return Storage singleton.
*/
static Storage* Get();
/*!
* \brief Get shared pointer reference to engine singleton.
* Most user should not call this function.
* This function is called by another singleton X who requires
* Storage to be destructed after X.
*
* \return A shared pointer to Storage singleton.
*/
static std::shared_ptr<Storage> _GetSharedRef();

private:
/*!
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
from . import io
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
from . import random

__version__ = "0.1.0"
8 changes: 6 additions & 2 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ def __init__(self, device_type, device_id=0):
device_id : int (default=0)
the device id of the device, needed for GPU
"""
self.device_mask = Context.devtype2mask[device_type]
self.device_id = device_id
if isinstance(device_type, Context):
self.device_mask = device_type.device_mask
self.device_id = device_type.device_id
else:
self.device_mask = Context.devtype2mask[device_type]
self.device_id = device_id
self._old_ctx = None

@property
Expand Down
4 changes: 1 addition & 3 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,7 @@ def zeros(shape, ctx=None):
----------
shape : tuple
shape of the NDArray.
ctx : Context, optional
ctx : Context, optional.
The context of the NDArray, default to current default context.
Returns
Expand All @@ -360,7 +359,6 @@ def ones(shape, ctx=None):
----------
shape : tuple
shape of the NDArray.
ctx : Context, optional
The context of the NDArray, default to current default context.
Expand Down
99 changes: 99 additions & 0 deletions python/mxnet/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# coding: utf-8
# pylint: disable=no-member, protected-access
"""Random Number interface of mxnet."""
from __future__ import absolute_import

import ctypes
from .base import _LIB, check_call
from .ndarray import NDArray, empty


def uniform(low, high, shape=None, ctx=None, out=None):
"""Generate uniform distribution in [low, high) with shape.
Parameters
----------
low : float
The lower bound of distribution.
high : float
The upper bound of distribution.
shape : tuple, optional
Output shape of the NDArray generated.
ctx : Context, optional
Context of output NDArray, will use default context if not specified.
out : NDArray, optional
Output place holder
Returns
-------
out : NDArray
The result NDArray with generated result.
"""
if out is not None:
if shape is not None or ctx is not None:
raise ValueError('shape and ctx is not needed when out is specified')
else:
if shape is None:
raise ValueError('shape is required when out is not specified')
if isinstance(shape, int):
shape = (shape,)
out = empty(shape, ctx)
return NDArray._random_uniform(low, high, out=out)


def normal(mean, stdvar, shape=None, ctx=None, out=None):
"""Generate normal(Gaussian) distribution N(mean, stdvar^2) with shape.
Parameters
----------
mean : float
The mean of the normal distribution.
stdvar : float
The standard deviation of normal distribution.
shape : tuple, optional
Output shape of the NDArray generated.
ctx : Context, optional
Context of output NDArray, will use default context if not specified.
out : NDArray, optional
Output place holder
Returns
-------
out : NDArray
The result NDArray with generated result.
"""
if out is not None:
if shape is not None or ctx is not None:
raise ValueError('shape and ctx is not needed when out is specified')
else:
if shape is None:
raise ValueError('shape is required when out is not specified')
if isinstance(shape, int):
shape = (shape,)
out = empty(shape, ctx)
return NDArray._random_gaussian(mean, stdvar, out=out)


def seed(seed_state):
"""Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
as well as results from executors that contains Random number
such as Dropout operators.
Parameters
----------
seed_state : int
The random number seed to set to all devices.
Notes
-----
The random number generator of mxnet is by default device specific.
This means if you set the same seed, the random number sequence
generated from GPU0 can be different from CPU.
"""
if not isinstance(seed_state, int):
raise ValueError('sd must be int')
seed_state = ctypes.c_int(int(seed_state))
check_call(_LIB.MXRandomSeed(seed_state))

6 changes: 6 additions & 0 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
}

// NOTE: return value is added in API_END
int MXRandomSeed(int seed) {
API_BEGIN();
mxnet::RandomSeed(seed);
API_END();
}

int MXNDArrayCreateNone(NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray();
Expand Down
Loading

0 comments on commit bc8ffb2

Please sign in to comment.