Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Changed scope of using-declarations (moved from namespace scope into the functions that use them)
Browse files: browse the repository at this point in the history
  • Loading branch information
stefanhenneking committed Jul 29, 2017
1 parent 1b46f21 commit 6cdf419
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/operator/tensor/cast_storage-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@

namespace mxnet {
namespace op {
using mshadow::cuda::kBaseThreadNum;
using mshadow::Shape1;
using mxnet_op::Kernel;

/*!
* \brief Thread kernel for marking non-zero rows of a tensor.
Expand Down Expand Up @@ -65,7 +62,7 @@ struct MarkRspRowIdxWarpKernel {
const index_t num_rows,
const index_t row_length) {
typedef cub::WarpReduce<index_t> WarpReduce;
const index_t warps_per_block = kBaseThreadNum / 32;
const index_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
__shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block];

const index_t warp_id = tid / 32; // global warp id
Expand Down Expand Up @@ -105,6 +102,7 @@ struct MarkRspRowIdxBlockKernel {
const DType* dns,
const index_t num_rows,
const index_t row_length) {
using mshadow::cuda::kBaseThreadNum;
typedef cub::BlockReduce<index_t, kBaseThreadNum> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
if (blockIdx.x < num_rows) {
Expand Down Expand Up @@ -195,13 +193,15 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
CHECK(rsp != nullptr);
CHECK_EQ(rsp->storage_type(), kRowSparseStorage);
CHECK_EQ(dns.shape_, rsp->shape());
using mshadow::Shape1;
using mxnet_op::Kernel;
mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type
const index_t num_rows = dns.shape_[0];
const index_t row_length = dns.shape_.ProdShape(1, dns.shape_.ndim());
const index_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
const index_t threads_per_block = kBaseThreadNum;
const index_t threads_per_block = mshadow::cuda::kBaseThreadNum;
const index_t min_num_warps = 512;
index_t num_threads;
// TODO: remove kernel dependency on warpSize=32
Expand Down Expand Up @@ -371,7 +371,7 @@ struct FillCsrIndPtrWarpKernel {
IType* indptr, const DType* dns,
const index_t num_rows, const index_t num_cols) {
typedef cub::WarpReduce<index_t> WarpReduce;
const index_t warps_per_block = kBaseThreadNum / 32;
const index_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
__shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block];

if (tid == 0) {
Expand Down Expand Up @@ -407,7 +407,7 @@ struct FillCsrColIdxAndValsWarpKernel {
const IType* indptr, const DType* dns,
const index_t num_rows, const index_t num_cols) {
typedef cub::WarpScan<index_t> WarpScan;
const index_t warps_per_block = kBaseThreadNum / 32;
const index_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
__shared__ typename WarpScan::TempStorage temp_storage[warps_per_block];
__shared__ volatile index_t warp_nnz[warps_per_block];

Expand Down Expand Up @@ -455,6 +455,7 @@ struct FillCsrIndPtrBlockKernel {
__device__ __forceinline__ static void Map(int tid,
IType* indptr, const DType* dns,
const index_t num_rows, const index_t num_cols) {
using mshadow::cuda::kBaseThreadNum;
typedef cub::BlockReduce<index_t, kBaseThreadNum> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

Expand Down Expand Up @@ -487,6 +488,7 @@ struct FillCsrColIdxAndValsBlockKernel {
DType* val, CType* col_idx,
const IType* indptr, const DType* dns,
const index_t num_rows, const index_t num_cols) {
using mshadow::cuda::kBaseThreadNum;
typedef cub::BlockScan<index_t, kBaseThreadNum> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ volatile index_t block_nnz;
Expand Down Expand Up @@ -534,14 +536,16 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
CHECK_EQ(csr->storage_type(), kCSRStorage);
CHECK_EQ(dns.shape_.ndim(), 2);
CHECK_EQ(dns.shape_, csr->shape());
using mshadow::Shape1;
using mxnet_op::Kernel;
mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col_idx type
const index_t num_rows = dns.shape_[0];
const index_t num_cols = dns.shape_[1];
const index_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
const index_t threads_per_block = kBaseThreadNum;
const index_t threads_per_block = mshadow::cuda::kBaseThreadNum;
const index_t min_num_warps = 512;
index_t num_threads;
// TODO: remove kernel dependency on warpSize=32
Expand Down

0 comments on commit 6cdf419

Please sign in to comment.