This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Custom Operator Random Number Generator Support #17762

Merged
28 commits merged on Apr 8, 2020
17 changes: 6 additions & 11 deletions CMakeLists.txt
@@ -584,7 +584,7 @@ if(USE_CUDA)
message("-- CUDA: Using the following NVCC architecture flags ${CUDA_ARCH_FLAGS}")
set(arch_code_list)
foreach(arch_str ${CUDA_ARCH_FLAGS})
if((arch_str MATCHES ".*sm_[0-9]+"))
string( REGEX REPLACE ".*sm_([0-9]+)" "\\1" arch_code ${arch_str} )
list(APPEND arch_code_list ${arch_code})
endif()
@@ -719,7 +719,7 @@ elseif(MSVC)
"$<$<COMPILE_LANGUAGE:CUDA>:--gpu-code=sm_${arch},compute_${arch}>"
)
target_compile_options(
mxnet_${arch}
PRIVATE "$<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-MTd -Gy /bigobj>")
target_compile_options(
mxnet_${arch}
@@ -748,26 +748,21 @@ elseif(MSVC)
endif()
endif()

# extension libraries (custom operators, custom subgraphs) are built by default
add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
if (USE_CUDA)
if(USE_CUDA)
add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
endif()
if(UNIX)
target_compile_options(customop_lib PUBLIC -shared)
target_compile_options(subgraph_lib PUBLIC -shared)
if (USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC -shared)
endif()
elseif(MSVC)
if(MSVC)
target_compile_options(customop_lib PUBLIC /LD)
target_compile_options(subgraph_lib PUBLIC /LD)
set_target_properties(customop_lib PROPERTIES PREFIX "lib")
set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
if (USE_CUDA)
if(USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
endif()
90 changes: 83 additions & 7 deletions example/extensions/lib_custom_op/relu_lib.cu
@@ -20,12 +20,14 @@
/*!
* Copyright (c) 2020 by Contributors
* \file relu_lib.cu
* \brief simple custom relu operator implemented using CUDA function
* \brief simple custom relu and noisy relu operator implemented using CUDA function
*/

#include <iostream>
#include "lib_api.h"

#define NumThreadPerBlock 256 // MXNet's recommended number of CUDA threads per block

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N)
@@ -72,9 +74,9 @@ MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);
int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;

relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(out_data, in_data, N);

return MX_SUCCESS;
}
@@ -89,9 +91,9 @@ MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_backward<<<grid,block,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;

relu_gpu_backward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(in_grad, out_grad, in_data, N);

return MX_SUCCESS;
}
@@ -180,6 +182,80 @@ REGISTER_OP(my_state_relu)
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");

/*
 * Below is the noisy ReLU operator example.
 * Noisy ReLU extends ReLU with additive Gaussian noise:
 *   forward  - add Gaussian noise drawn from a normal distribution to each unit
 *   backward - the gradient is unchanged, since the noise is constant
 */
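
/*
 * In symbols (per unit i, with eps[i] the Gaussian noise draw), a sketch of
 * what the forward pass below computes:
 *   out[i] = max(in[i] + eps[i], 0),  eps[i] ~ N(0, 1)
 * Since eps[i] is constant with respect to in[i], the gradient is just the
 * plain ReLU gradient, which is why my_noisy_relu below reuses
 * backwardCPU/backwardGPU unchanged.
 */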

#define NumRandomPerThread 64 // MXNet's recommended number of random values to generate per thread

__global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, mx_gpu_rand_t* states, int step) {
// the launch logic ensures tid is less than MX_NUM_GPU_RANDOM_STATES
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// each thread generates a unique sequence of random numbers
mx_gpu_rand_t thread_state = states[tid];
// each thread processes <step> elements
int start = tid * step;
int end = start + step;
for (int i=start; i<end && i<N; ++i) {
float noise = curand_normal(&thread_state);
out[i] = in[i] + noise > 0 ? in[i] + noise : 0;
}
}

MXReturnValue noisyForwardCPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_cpu_rand_t* states = res.get_cpu_rand_states();
std::normal_distribution<float> dist_normal;

for (int i=0; i<inputs[0].size(); ++i) {
float noise = dist_normal(*states);
out_data[i] = in_data[i] + noise > 0 ? in_data[i] + noise : 0;
}
return MX_SUCCESS;
}

MXReturnValue noisyForwardGPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();

// below is MXNet's recommended workflow for parallelizing random number generation
int nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread;
// do not launch more threads than the number of GPU random states MXNet provides
int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES ? nthread : MX_NUM_GPU_RANDOM_STATES;
// each cuda thread processes the [step * tid, step * tid + step) slice of the input tensor
int step = (N + num_thread_need - 1) / num_thread_need;
// this ensures the number of parallel threads does not exceed MXNet's supported random states
int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock;

noisy_relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(
out_data, in_data, N, res.get_gpu_rand_states(), step);

return MX_SUCCESS;
}
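
/*
 * Worked example of the launch arithmetic above, assuming N = 1,000,000:
 *   nthread         = ceil(1,000,000 / 64)     = 15,625  (< MX_NUM_GPU_RANDOM_STATES = 32,768)
 *   num_thread_need = 15,625
 *   step            = ceil(1,000,000 / 15,625) = 64      (elements per thread)
 *   num_block       = ceil(15,625 / 256)       = 62      (15,872 threads launched)
 * Threads whose start index (tid * step) lands at or past N simply fall out of
 * the kernel's `i < N` loop bound immediately.
 */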

REGISTER_OP(my_noisy_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(noisyForwardCPU, "cpu")
.setForward(noisyForwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");

MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
43 changes: 27 additions & 16 deletions example/extensions/lib_custom_op/test_relu.py
@@ -35,13 +35,13 @@
a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu())
b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())

print("--------start ndarray compute---------")
print("--------ndarray compute---------")
print(mx.nd.my_relu(a))
print(mx.nd.my_relu(b))
print(mx.nd.my_state_relu(a))
print(mx.nd.my_state_relu(b))

print("--------start symbolic compute--------")
print("--------symbolic compute--------")
c = mx.sym.Variable('c')
d = mx.sym.Variable('d')
e = mx.sym.my_relu(c)
@@ -55,30 +55,41 @@
print(out)
print(out_base)

print("--------start backward compute--------")
print("--------backward compute--------")
out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
exe.backward([out_grad])
exe_base.backward([out_grad])
print(in_grad)
print(in_grad_base)

print("--------start testing larger ndarray---------")
a = mx.nd.uniform(shape=(100,100,100), ctx=mx.cpu())
print("--------test ndarray with size of 1 million---------")
b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu())
mx.nd.waitall()
t1 = time.time()
r1 = mx.nd.my_relu(a)
r1 = mx.nd.my_relu(b)
mx.nd.waitall()
t2 = time.time()
r2 = mx.nd.my_relu(b)
r2 = mx.nd.relu(b)
mx.nd.waitall()
t3 = time.time()
r3 = mx.nd.relu(b)
mx.nd.waitall()
t4 = time.time()
print("CPU running time:")
print(t2 - t1)
print("GPU running time:")
print(t3 - t2)
print("Baseline GPU running time:")
print(t4 - t3)
print("Custom ReLU running time in ms:")
print((t2 - t1) * 1000)
print("Native ReLU running time in ms:")
print((t3 - t2) * 1000)

print("--------test noisy relu identical sequence---------")

a = mx.nd.ones(shape=(13,5), ctx=mx.cpu())
b = mx.nd.ones(shape=(13,5), ctx=mx.gpu())

mx.random.seed(128, ctx=mx.cpu())
print(mx.nd.my_noisy_relu(a))

mx.random.seed(128, ctx=mx.cpu())
print(mx.nd.my_noisy_relu(a))

mx.random.seed(128, ctx=mx.gpu())
print(mx.nd.my_noisy_relu(b))

mx.random.seed(128, ctx=mx.gpu())
print(mx.nd.my_noisy_relu(b))
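
With the same seed on the same context, the two printed outputs above should be identical: the noisy relu draws from MXNet's seeded random states, so reseeding reproduces the sequence. CPU and GPU results need not match each other, since they come from different generators (std::mt19937 on CPU vs. cuRAND Philox on GPU, per the typedefs in lib_api.h).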
57 changes: 46 additions & 11 deletions include/mxnet/lib_api.h
@@ -38,8 +38,14 @@
#include <iostream>
#include <utility>
#include <stdexcept>
#include <random>

#define MX_LIBRARY_VERSION 5
#if defined(__NVCC__)
#include <curand_kernel.h>
#endif

/* Make sure to update the version number every time you make changes */
#define MX_LIBRARY_VERSION 6

/*!
* \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -395,8 +401,8 @@ struct MXTensor {
stype == oth.stype;
}

// For dense, data_ptr points to data.
// For sparse, data_ptr points to MXSparse.
// For dense, data_ptr points to 1D flattened tensor data
// For sparse, data_ptr points to MXSparse
void *data_ptr;

// shape is in [2,3,4] format to represent high-dim tensor
@@ -426,9 +432,17 @@ typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t

#if defined(__NVCC__)
typedef cudaStream_t mx_stream_t;
typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
#else
typedef void* mx_stream_t;
typedef void* mx_gpu_rand_t;
#endif
typedef std::mt19937 mx_cpu_rand_t;

/*! \brief Random number states for each device, initialized by MXNet and used for parallelism */
/* Each thread should generate its own unique random sequence from a different state */
#define MX_NUM_CPU_RANDOM_STATES 1024
#define MX_NUM_GPU_RANDOM_STATES 32768

/*!
* \brief provide resource APIs (memory allocation, streams, random states) to Forward/Backward functions
@@ -437,10 +451,12 @@ class OpResource {
public:
OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp)
sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
void* rng_cpu_states, void* rng_gpu_states)
: cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp) {}
sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp),
rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {}

/*! \brief allocate cpu memory controlled by MXNet */
void* alloc_cpu(int size) {
@@ -463,6 +479,19 @@ class OpResource {
&(sparse->data), &(sparse->indices), &(sparse->indptr));
}

/*! \brief get pointer to initialized and seeded random number states located on CPU */
/* Access each state as states[id], where id must be less than MX_NUM_CPU_RANDOM_STATES */
mx_cpu_rand_t* get_cpu_rand_states() {
return static_cast<mx_cpu_rand_t*>(rand_cpu_states);
}

/*! \brief get pointer to initialized and seeded random number states located on GPU */
/* Access each state as states[id], where id must be less than MX_NUM_GPU_RANDOM_STATES */
/* Note: in a CPU-only build this returns nullptr */
mx_gpu_rand_t* get_gpu_rand_states() {
return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
}

private:
/*! \brief allocation lambda function */
xpu_malloc_t cpu_malloc, gpu_malloc;
@@ -474,6 +503,8 @@ class OpResource {
sparse_malloc_t sparse_malloc;
/*! \brief lambda function to return allocated sparse memory handle */
void *sparse_alloc;
/*! \brief cpu and gpu rng fully inited and seeded states */
void *rand_cpu_states, *rand_gpu_states;
};
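
/*
 * A minimal usage sketch (hypothetical forward function mirroring the
 * noisyForwardCPU example in this PR; not part of this header):
 *
 *   MXReturnValue myForward(std::map<std::string, std::string> attrs,
 *                           std::vector<MXTensor> inputs,
 *                           std::vector<MXTensor> outputs,
 *                           OpResource res) {
 *     // states points to MX_NUM_CPU_RANDOM_STATES seeded generators
 *     mx_cpu_rand_t* states = res.get_cpu_rand_states();
 *     std::normal_distribution<float> dist;
 *     float* out = outputs[0].data<float>();
 *     for (int64_t i = 0; i < outputs[0].size(); ++i)
 *       out[i] = dist(states[0]);  // any index < MX_NUM_CPU_RANDOM_STATES is valid
 *     return MX_SUCCESS;
 *   }
 */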

/*!
@@ -997,7 +1028,8 @@ typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states);

#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
@@ -1026,7 +1058,8 @@ typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states);

#define MXLIB_PARTREGSIZE_STR "_partRegSize"
typedef int (*partRegSize_t)(void);
@@ -1284,7 +1317,8 @@ extern "C" {
int* instypes, int* outstypes, void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states) {
// create map of attributes from list
std::map<std::string, std::string> attrs;
for (int i = 0; i < num; i++) {
@@ -1345,7 +1379,7 @@ extern "C" {
}

OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
cuda_stream, sparse_malloc, sparse_alloc);
cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
return fcomp(attrs, inputs, outputs, res);
}

@@ -1419,7 +1453,8 @@ extern "C" {
int* instypes, int* outstypes, void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states) {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
// create a vector for sparse inputs
Expand Down Expand Up @@ -1476,7 +1511,7 @@ extern "C" {
}

OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
stream, sparse_malloc, sparse_alloc);
stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);

CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
if (is_forward) {