diff --git a/driver/random.hpp b/driver/random.hpp
index 6398048dde..b3be81f56e 100644
--- a/driver/random.hpp
+++ b/driver/random.hpp
@@ -91,6 +91,7 @@ inline T gen_0_to_B(T B)
 template <typename T>
 inline T gen_A_to_B(T A, T B)
 {
+    assert(B > A);
     return gen_0_to_B(B - A) + A;
 }
diff --git a/fin b/fin
index 26b5c32864..afc1a8d87e 160000
--- a/fin
+++ b/fin
@@ -1 +1 @@
-Subproject commit 26b5c328642a6af5041539ceae36b9340829384b
+Subproject commit afc1a8d87e6d00c82903942007bb370ee1f6c760
diff --git a/src/include/miopen/convolution.hpp b/src/include/miopen/convolution.hpp
index bac0133106..35c494eab2 100644
--- a/src/include/miopen/convolution.hpp
+++ b/src/include/miopen/convolution.hpp
@@ -36,6 +36,7 @@
 #include
 #include
 #include
+#include
 #include
@@ -404,6 +405,9 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor
     friend void to_json(nlohmann::json& json, const ConvolutionDescriptor& conv);
     friend void from_json(const nlohmann::json& json, ConvolutionDescriptor& conv);
+
+private:
+    void ValidateTensors(const ConvTensors& conv_tensors) const;
 };

 void ConvolutionBackwardBias(const Handle& handle,
diff --git a/src/kernels/gpu_reference_kernel/naive_conv.cpp b/src/kernels/gpu_reference_kernel/naive_conv.cpp
index b243b1234a..125eff94f3 100644
--- a/src/kernels/gpu_reference_kernel/naive_conv.cpp
+++ b/src/kernels/gpu_reference_kernel/naive_conv.cpp
@@ -126,9 +126,9 @@ inline __device__ __host__ int8_t cast_to(const int32_t& val)
 /// composable_kernel (CK) treats the G dimension, which is why nchw should be ngchw,
 /// and nhwc should be nhwgc. The same follows for the 3D case.
 ///
-/// - strides here are in the little-endian order, i.e., for NHWC, stride for N is
-///   at index 3 while stride for C is at index 0. This is reverse of how strides are
-///   stored in tensor descriptors, which are big-endian.
+/// - strides here are stored right to left, i.e., for NHWC, the stride for N is
+///   at index 3 while the stride for C is at index 0. This differs from how the
+///   tensor descriptors store strides, which is always NCHW order, left-to-right.
 template
 inline __device__ void naive_conv_fwd_nchw(const src_data_t* __restrict__ p_in,
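A quick illustration of the two stride conventions described in the comment above (an editorial sketch, not part of the patch; the sizes are made up):

    // Stride orders for a packed NHWC tensor with H = 4, W = 4, C = 8.
    #include <array>

    constexpr int H = 4, W = 4, C = 8;

    // Packed NHWC: C moves fastest, N slowest.
    constexpr int sC = 1, sW = C, sH = W * C, sN = H * W * C;

    // A tensor descriptor stores strides in NCHW order, left to right.
    constexpr std::array<int, 4> descriptor_order{sN, sC, sH, sW};

    // The naive kernels expect them right to left: index 0 is C, index 3 is N.
    constexpr std::array<int, 4> kernel_order{sC, sW, sH, sN};

    static_assert(descriptor_order[0] == kernel_order[3], "same N stride");
    static_assert(descriptor_order[1] == kernel_order[0], "same C stride");

Both arrays describe the same memory layout; only the indexing convention differs.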
diff --git a/src/ocl/convolutionocl.cpp b/src/ocl/convolutionocl.cpp
index 94b083577d..d66186577c 100644
--- a/src/ocl/convolutionocl.cpp
+++ b/src/ocl/convolutionocl.cpp
@@ -287,30 +287,6 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle,

 namespace {

-void ValidateConvTensors(const ConvTensors& tensors)
-{
-    const auto invalid_buffers =
-        tensors.x == nullptr || tensors.w == nullptr || tensors.y == nullptr;
-
-    const auto tensor_sizes_not_matched = tensors.xDesc.GetSize() != tensors.yDesc.GetSize() ||
-                                          tensors.xDesc.GetSize() != tensors.wDesc.GetSize();
-
-    const auto trivial_tensor_types_not_matched =
-        tensors.xDesc.GetType() != tensors.yDesc.GetType() && tensors.xDesc.GetType() != miopenInt8;
-
-    // if(xDesc.GetLengths()[1] != wDesc.GetLengths()[1]) {
-    //    MIOPEN_THROW(miopenStatusBadParm);
-    //}
-
-    const auto x_tensor_invalid = tensors.xDesc.GetSize() < 3;
-
-    const auto bad_parameters = invalid_buffers || tensor_sizes_not_matched ||
-                                trivial_tensor_types_not_matched || x_tensor_invalid;
-
-    if(bad_parameters)
-        MIOPEN_THROW(miopenStatusBadParm);
-}
-
 void ValidateAlphaBeta(const void* alpha, const void* beta)
 {
     if(!float_equal(*(static_cast<const float*>(alpha)), 1.0) ||
@@ -401,6 +377,88 @@ static void ConvForwardCheckNumerics(const Handle& handle,
     }
 }

+void ConvolutionDescriptor::ValidateTensors(const ConvTensors& tensors) const
+{
+
+    // Group stride in the current TensorDescriptor is implicit. When invoking kernels,
+    // we need to add the group dimension G and compute its stride. We want the stride
+    // left of C to be a multiple of the group count G, e.g. for NCHW the stride for N
+    // should be a multiple of G so that we can compute the strides for NGCHW.
+    auto bad_group_stride = [this](const TensorDescriptor& td) {
+        auto l             = td.GetLayout_t();
+        int g_stride_index = -1;
+        if(l == miopenTensorNCHW || l == miopenTensorNCDHW)
+        {
+            g_stride_index = 0; // stride index for N
+        }
+        else if(l == miopenTensorNHWC || l == miopenTensorNDHWC)
+        {
+            // stride index for W. This would normally be the second-to-last stride,
+            // but tensor descriptors always store strides in NCHW order.
+            g_stride_index = td.GetStrides().size() - 1;
+        }
+        else
+        {
+            MIOPEN_THROW(miopenStatusInternalError, "Layout not supported for grouped convolution");
+        }
+
+        if(g_stride_index != -1)
+        {
+            return (td.GetStrides()[g_stride_index] % this->group_count) != 0;
+        }
+
+        return false;
+    };
+
+    // invalid_buffers
+    if(tensors.x == nullptr || tensors.w == nullptr || tensors.y == nullptr)
+    {
+        MIOPEN_THROW(miopenStatusBadParm, "One of the convolution tensors is null");
+    }
+
+    // x_tensor_invalid =
+    if(tensors.xDesc.GetSize() < 3)
+    {
+        MIOPEN_THROW(miopenStatusBadParm, "input tensor must have at least 3 dimensions");
+    }
+
+    // tensor_sizes_not_matched =
+    if(tensors.xDesc.GetSize() != tensors.yDesc.GetSize() ||
+       tensors.xDesc.GetSize() != tensors.wDesc.GetSize())
+    {
+        MIOPEN_THROW(miopenStatusBadParm,
+                     "number of dimensions mismatch between input, output and weights tensors");
+    }
+
+    // trivial_tensor_types_not_matched =
+    if(tensors.xDesc.GetType() != tensors.yDesc.GetType() &&
+       tensors.xDesc.GetType() != miopenInt8 && tensors.xDesc.GetType() != miopenInt8x4)
+    {
+        MIOPEN_THROW(miopenStatusBadParm, "input/output tensor data types do not match");
+    }
+
+    // check for bad_group_stride. This applies to input and output only. There
+    // is no check for the weight tensor currently.
+    // no need to check for group_count == 1
+
+    if((this->group_count > 1) && bad_group_stride(tensors.xDesc))
+    {
+        MIOPEN_THROW(
+            miopenStatusBadParm,
+            "Invalid input tensor strides. Channel stride must be a multiple of group count");
+    }
+    if((this->group_count > 1) && bad_group_stride(tensors.yDesc))
+    {
+        MIOPEN_THROW(
+            miopenStatusBadParm,
+            "Invalid output tensor strides. Channel stride must be a multiple of group count");
+    }
+
+    // if(xDesc.GetLengths()[1] != wDesc.GetLengths()[1]) {
+    //    MIOPEN_THROW(miopenStatusBadParm);
+    //}
+}
+
 void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
                                                const void* alpha,
                                                const TensorDescriptor& xDesc,
@@ -416,13 +474,8 @@ void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
 {
     MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);

-    if(!(xDesc.IsPacked() && wDesc.IsPacked() && yDesc.IsPacked()))
-    {
-        MIOPEN_THROW(miopenStatusNotImplemented, "Only fully packed tensors are supported");
-    }
-
     const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};
-    ValidateConvTensors(tensors);
+    ValidateTensors(tensors);
     ValidateAlphaBeta(alpha, beta);

     ConvForwardCheckNumerics(handle, tensors, [&]() {
@@ -735,7 +788,7 @@ void ConvolutionDescriptor::ConvolutionForwardImmediate(Handle& handle,
     MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
     const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};

-    ValidateConvTensors(tensors);
+    ValidateTensors(tensors);

     if(!solver_id.IsValid())
         MIOPEN_THROW(miopenStatusBadParm);
@@ -871,7 +924,7 @@ void ConvolutionDescriptor::ConvolutionBackwardData(Handle& handle,

     auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

-    ValidateConvTensors(tensors);
+    ValidateTensors(tensors);
     ValidateAlphaBeta(alpha, beta);

     ConvBwdCheckNumerics(handle, tensors, beta, [&]() {
@@ -937,7 +990,7 @@ void ConvolutionDescriptor::ConvolutionBackwardImmediate(Handle& handle,
     MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
     auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

-    ValidateConvTensors(tensors);
+    ValidateTensors(tensors);

     static const float beta = 0.0f;
     ConvBwdCheckNumerics(handle, tensors, &beta, [&]() {
@@ -1071,7 +1124,7 @@ void ConvolutionDescriptor::ConvolutionBackwardWeights(const Handle& handle,
 {
     MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
     decltype(auto) tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
-    ValidateConvTensors(tensors);
+    ValidateTensors(tensors);
     ValidateAlphaBeta(alpha, beta);

     if(xDesc.GetType() == miopenInt8)
@@ -1134,7 +1187,7 @@ void ConvolutionDescriptor::ConvolutionWrwImmediate(Handle& handle,
 {
     MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
     auto tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
-    ValidateConvTensors(tensors);
+    ValidateTensors(tensors);

     if(xDesc.GetType() == miopenInt8)
         MIOPEN_THROW(miopenStatusBadParm);
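To make the group-stride requirement in ValidateTensors concrete, here is a minimal sketch (editorial, not part of the patch; the sizes are hypothetical and the derivation follows the reasoning in the comment above, i.e. that the G stride comes from splitting the stride left of C):

    #include <array>

    constexpr int C = 8, H = 4, W = 4, G = 2;

    // A non-packed NCHW tensor: the W stride is padded from 1 to 2.
    constexpr int sW = 2, sH = W * sW, sC = H * sH, sN = C * sC; // sN == 256

    // Viewing C as G x (C/G) inserts a G dimension between N and C whose
    // stride is the size of one group slice, (C / G) * sC. With this layout
    // sN == C * sC, so that group stride equals sN / G, and deriving it from
    // the N stride only works when G divides that stride evenly. This is
    // what bad_group_stride() checks.
    static_assert(sN % G == 0, "stride left of C must be a multiple of G");
    static_assert((C / G) * sC == sN / G, "NGCHW G stride");

    constexpr std::array<int, 5> ngchw_strides{sN, sN / G, sC, sH, sW};
    static_assert(ngchw_strides[1] == 128, "one group slice apart");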
diff --git a/src/solver/conv_direct_naive_conv.cpp b/src/solver/conv_direct_naive_conv.cpp
index 992b196b45..f87511f911 100644
--- a/src/solver/conv_direct_naive_conv.cpp
+++ b/src/solver/conv_direct_naive_conv.cpp
@@ -111,11 +111,15 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_USE_PACKED_KERNELS);
 std::string ConvDirectNaiveConvKernelName(const ProblemDescription& problem)
 {
     std::ostringstream kernel_name;
+
+    /// \todo remove packed reference convolution kernels --amberhassaan
+#ifndef NDEBUG // enable in debug mode only
     if(miopen::IsEnabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_USE_PACKED_KERNELS()))
     {
         kernel_name << "naive_conv_packed_";
     }
     else
+#endif
     {
         kernel_name << "naive_conv_nonpacked_";
     }
diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp
index 1a28f8aae6..dea91c9ecf 100644
--- a/src/solver/conv_direct_naive_conv_bwd.cpp
+++ b/src/solver/conv_direct_naive_conv_bwd.cpp
@@ -134,12 +134,8 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx,
     kernel.l_wk.push_back(1);
     kernel.l_wk.push_back(1);

-    const auto is_f8 = [&]() {
-        if(kernel.kernel_file == "fp8_naive_conv.cpp")
-            return true;
-        else
-            return false;
-    }();
+    const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");
+
     kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);

     int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);
diff --git a/src/solver/conv_direct_naive_conv_fwd.cpp b/src/solver/conv_direct_naive_conv_fwd.cpp
index a4656d929a..5bc25a2367 100644
--- a/src/solver/conv_direct_naive_conv_fwd.cpp
+++ b/src/solver/conv_direct_naive_conv_fwd.cpp
@@ -122,12 +122,6 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
     KernelInfo kernel;

     kernel.kernel_file = ConvDirectNaiveConvKernelFile(ctx, problem);
-    const auto is_f8 = [&]() {
-        if(kernel.kernel_file == "fp8_naive_conv.cpp")
-            return true;
-        else
-            return false;
-    }();
     kernel.kernel_name = ConvDirectNaiveConvKernelName(problem);

     kernel.g_wk.clear();
@@ -139,6 +133,8 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
     kernel.l_wk.push_back(1);
     kernel.l_wk.push_back(1);

+    const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");
+
     kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);

     int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);
diff --git a/src/solver/conv_direct_naive_conv_wrw.cpp b/src/solver/conv_direct_naive_conv_wrw.cpp
index dfe1c342b0..a8c4d40e0b 100644
--- a/src/solver/conv_direct_naive_conv_wrw.cpp
+++ b/src/solver/conv_direct_naive_conv_wrw.cpp
@@ -121,13 +121,9 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx,
     kernel.l_wk.push_back(1);
     kernel.l_wk.push_back(1);

+    const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");
+
     kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);
-    const auto is_f8 = [&]() {
-        if(kernel.kernel_file == "fp8_naive_conv.cpp")
-            return true;
-        else
-            return false;
-    }();

     int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);
diff --git a/test/gpu_reference_kernel.cpp b/test/gpu_reference_kernel.cpp
index be8f3f8430..e166781b9b 100644
--- a/test/gpu_reference_kernel.cpp
+++ b/test/gpu_reference_kernel.cpp
@@ -24,6 +24,8 @@
  *
  *******************************************************************************/

+#include
+#include
 #include
 #include
 #include
@@ -73,17 +75,9 @@ std::string tensor_layout_to_string(tensor_layout_t layout)
 struct gpu_reference_kernel_base
 {
     miopenHandle_t handle{};
-#if MIOPEN_BACKEND_OPENCL
-    cl_command_queue q{};
-#endif
-    gpu_reference_kernel_base()
-    {
-        miopenCreate(&handle);
-#if MIOPEN_BACKEND_OPENCL
-        miopenGetStream(handle, &q);
-#endif
-    }
+    gpu_reference_kernel_base() { miopenCreate(&handle); }
+
     ~gpu_reference_kernel_base() { miopenDestroy(handle); }

     static int conv_out_size(int in_size, int pad, int dilation, int ksize, int stride)
@@ -308,6 +302,21 @@ static std::string miopen_type_to_string(miopenDataType_t type)
     return "n/a";
 }

+/// input: a vector of strides of a tensor;
+/// multiply each element by a random constant integer
+void pad_tensor_strides(std::vector<int>& strides)
+{
+    constexpr int min_stride_multiplier = 1;
+    constexpr int max_stride_multiplier = 5;
+
+    auto c = prng::gen_A_to_B(min_stride_multiplier, max_stride_multiplier);
+    for(auto& v : strides)
+    {
+        // cppcheck-suppress useStlAlgorithm
+        v = v * c;
+    }
+}
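For reference, this is what the stride padding does to a small tensor (an editorial sketch, not part of the patch): every stride is scaled by the same constant, so the element order is preserved but gaps open up between elements, making the tensor non-packed.

    #include <cassert>
    #include <vector>

    int main()
    {
        std::vector<int> len{2, 3};
        std::vector<int> stride{3, 1}; // packed 2x3 tensor

        const int c = 2; // fixed here; the test draws it from prng::gen_A_to_B(1, 5)
        for(auto& s : stride)
            s *= c;

        // Element (i, j) now lives at offset i * 6 + j * 2, so the buffer needs
        // (2 - 1) * 6 + (3 - 1) * 2 + 1 = 11 slots for only 6 values.
        assert(stride[0] == 6 && stride[1] == 2);
        return 0;
    }

This is also why the test below stops computing in_sz/wei_sz/out_sz from the tensor lengths and takes tensor.data.size() instead: with padded strides the backing buffer is larger than the element count.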
 template
         std::vector<int> in_len({n, c, hi, wi});
         std::vector<int> wei_len({k, c_per_group, fy, fx});
@@ -360,28 +364,25 @@
         miopen::tensor_layout_to_strides(wei_len, layout_default, layout_string, wei_strides);
         miopen::tensor_layout_to_strides(out_len, layout_default, layout_string, out_strides);

+        pad_tensor_strides(in_strides);
+        pad_tensor_strides(wei_strides);
+        pad_tensor_strides(out_strides);
+
         tensor<TRef> in(in_len, in_strides);
         tensor<TRef> wei(wei_len, wei_strides);
         tensor<Tout> out(out_len, out_strides);
-#if MIOPEN_BACKEND_OPENCL
-        cl_context ctx;
-        clGetCommandQueueInfo(q, CL_QUEUE_CONTEXT, sizeof(cl_context), &ctx, nullptr);
-        cl_int status = CL_SUCCESS;
-        cl_mem in_dev =
-            clCreateBuffer(ctx, CL_MEM_READ_WRITE, sizeof(TRef) * in_sz, nullptr, &status);
-        cl_mem wei_dev =
-            clCreateBuffer(ctx, CL_MEM_READ_WRITE, sizeof(TRef) * wei_sz, nullptr, nullptr);
-        cl_mem out_dev =
-            clCreateBuffer(ctx, CL_MEM_READ_WRITE, sizeof(Tout) * out_sz, nullptr, nullptr);
-        EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+
+        auto in_sz  = in.data.size();
+        auto wei_sz = wei.data.size();
+        auto out_sz = out.data.size();
+
         void* in_dev;
         void* wei_dev;
         void* out_dev;
         EXPECT(hipMalloc(&in_dev, sizeof(TRef) * in_sz) == hipSuccess);
         EXPECT(hipMalloc(&wei_dev, sizeof(TRef) * wei_sz) == hipSuccess);
         EXPECT(hipMalloc(&out_dev, sizeof(Tout) * out_sz) == hipSuccess);
-#endif
+
         EXPECT(miopenCreateConvolutionDescriptor(&convDesc) == miopenStatusSuccess);
         EXPECT(miopenInitConvolutionNdDescriptor(convDesc,
                                                  2,
@@ -417,27 +418,9 @@
             // initialize data with integer
             rand_tensor_integer(in);
             rand_tensor_integer(wei);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueWriteBuffer(q,
-                                          in_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * in_sz,
-                                          in.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            status |= clEnqueueWriteBuffer(q,
-                                          wei_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * wei_sz,
-                                          wei.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+            /// \ref copy_non_packed_output_before_convolution
+            rand_tensor_integer(out);
+
             EXPECT(hipMemcpy(
                        in_dev, in.data.data(), sizeof(TRef) * in_sz, hipMemcpyHostToDevice) ==
                    hipSuccess);
             EXPECT(hipMemcpy(wei_dev,
                              wei.data.data(),
                              sizeof(TRef) * wei_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
-#endif
+            /// \anchor copy_non_packed_output_before_convolution
+            /// \note Output is a non-packed tensor, which means there are
+            /// elements that convolution will not update. In order to verify
+            /// the convolution result, the GPU buffer should have the same
+            /// data as the CPU in both updated and not-updated elements.
+            /// Therefore, we copy the output to the GPU buffer after
+            /// initializing it with random values.
+            ///
+            EXPECT(hipMemcpy(out_dev,
+                             out.data.data(),
+                             sizeof(Tout) * out_sz,
+                             hipMemcpyHostToDevice) == hipSuccess);
+
             cpu_convolution_forward(miopen::deref(convDesc).GetSpatialDimension(),
                                     in,
                                     wei,
@@ -470,23 +465,11 @@
                    miopenStatusSuccess);

             tensor<Tout> out_host(out_len, out_strides);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueReadBuffer(q,
-                                         out_dev,
-                                         CL_TRUE,
-                                         0,
-                                         sizeof(Tout) * out_sz,
-                                         out_host.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
             EXPECT(hipMemcpy(out_host.data.data(),
                              out_dev,
                              sizeof(Tout) * out_sz,
                              hipMemcpyDeviceToHost) == hipSuccess);
-#endif
+
             // we expect exact match, since we use integers
             valid_result = verify_tensor(out_host, out);
         }
@@ -495,36 +478,22 @@
         {
             // initialize data with integer
             rand_tensor_integer(out);
             rand_tensor_integer(wei);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueWriteBuffer(q,
-                                          out_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * out_sz,
-                                          out.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            status |= clEnqueueWriteBuffer(q,
-                                          wei_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * wei_sz,
-                                          wei.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+            /// \ref copy_non_packed_output_before_convolution
+            rand_tensor_integer(in);
+            /// \ref copy_non_packed_output_before_convolution
+
+            EXPECT(hipMemcpy(
+                       in_dev, in.data.data(), sizeof(TRef) * in_sz, hipMemcpyHostToDevice) ==
+                   hipSuccess);
             EXPECT(hipMemcpy(out_dev,
                              out.data.data(),
-                             sizeof(TRef) * out_sz,
+                             sizeof(Tout) * out_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
             EXPECT(hipMemcpy(wei_dev,
                              wei.data.data(),
                              sizeof(TRef) * wei_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
-#endif
+
             cpu_convolution_backward_data(miopen::deref(convDesc).GetSpatialDimension(),
                                           in,
                                           wei,
@@ -549,23 +518,11 @@
                                           miopenStatusSuccess);

             tensor<TRef> in_host(in_len, in_strides);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueReadBuffer(q,
-                                         in_dev,
-                                         CL_TRUE,
-                                         0,
-                                         sizeof(TRef) * in_sz,
-                                         in_host.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+
             EXPECT(hipMemcpy(in_host.data.data(),
                              in_dev,
                              sizeof(TRef) * in_sz,
                              hipMemcpyDeviceToHost) == hipSuccess);
-#endif

             // we expect exact match, since we use integers
             valid_result = verify_tensor(in_host, in);
         }
@@ -574,35 +531,22 @@
         {
             rand_tensor_integer(in);
             rand_tensor_integer(out);
-#if MIOPEN_BACKEND_OPENCL
-            status |= clEnqueueWriteBuffer(q,
-                                           in_dev,
-                                           CL_TRUE,
-                                           0,
-                                           sizeof(TRef) * in_sz,
-                                           in.data.data(),
-                                           0,
-                                           nullptr,
-                                           nullptr);
-            status |= clEnqueueWriteBuffer(q,
-                                           out_dev,
-                                           CL_TRUE,
-                                           0,
-                                           sizeof(TRef) * out_sz,
-                                           out.data.data(),
-                                           0,
-                                           nullptr,
-                                           nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+            /// \ref copy_non_packed_output_before_convolution
+            rand_tensor_integer(wei);
+
             EXPECT(hipMemcpy(
                        in_dev, in.data.data(), sizeof(TRef) * in_sz, hipMemcpyHostToDevice) ==
                    hipSuccess);
+            /// \ref copy_non_packed_output_before_convolution
+            EXPECT(hipMemcpy(wei_dev,
+                             wei.data.data(),
+                             sizeof(TRef) * wei_sz,
+                             hipMemcpyHostToDevice) == hipSuccess);
             EXPECT(hipMemcpy(out_dev,
                              out.data.data(),
-                             sizeof(TRef) * out_sz,
+                             sizeof(Tout) * out_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
-#endif
+
             cpu_convolution_backward_weight(miopen::deref(convDesc).GetSpatialDimension(),
                                             in,
                                             wei,
@@ -627,23 +571,11 @@ struct gpu_reference_conv_2d : gpu_reference_kernel_base
                                             miopenStatusSuccess);

             tensor<TRef> wei_host(wei_len, wei_strides);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueReadBuffer(q,
-                                         wei_dev,
-                                         CL_TRUE,
-                                         0,
-                                         sizeof(TRef) * wei_sz,
-                                         wei_host.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+
             EXPECT(hipMemcpy(wei_host.data.data(),
                              wei_dev,
                              sizeof(TRef) * wei_sz,
                              hipMemcpyDeviceToHost) == hipSuccess);
-#endif

             // we expect exact match, since we use integers
             valid_result = verify_tensor(wei_host, wei);
         }
@@ -665,15 +597,10 @@
         miopenDestroyTensorDescriptor(inDesc);
         miopenDestroyTensorDescriptor(weiDesc);
         miopenDestroyTensorDescriptor(outDesc);
-#if MIOPEN_BACKEND_OPENCL
-        clReleaseMemObject(in_dev);
-        clReleaseMemObject(wei_dev);
-        clReleaseMemObject(out_dev);
-#elif MIOPEN_BACKEND_HIP
+
         hipFree(in_dev);
         hipFree(wei_dev);
         hipFree(out_dev);
-#endif
     };

     iterate_conv_2d(run_conv_2d);
 }
 };
@@ -717,11 +644,6 @@ struct gpu_reference_conv_3d : gpu_reference_kernel_base
         int wo          = conv_out_size(wi, px, dx, fx, sx);
         int do_         = conv_out_size(di, pz, dz, fz, sz);
         int c_per_group = c / g;
-        int k_per_group = k / g;
-
-        int in_sz  = g * n * c_per_group * di * hi * wi;
-        int wei_sz = g * k_per_group * c_per_group * fz * fy * fx;
-        int out_sz = g * n * k_per_group * do_ * ho * wo;

         std::vector<int> in_len({n, c, di, hi, wi});
         std::vector<int> wei_len({k, c_per_group, fz, fy, fx});
@@ -738,28 +660,26 @@
         miopen::tensor_layout_to_strides(wei_len, layout_default, layout_string, wei_strides);
         miopen::tensor_layout_to_strides(out_len, layout_default, layout_string, out_strides);

+        pad_tensor_strides(in_strides);
+        pad_tensor_strides(wei_strides);
+        pad_tensor_strides(out_strides);
+
         tensor<TRef> in(in_len, in_strides);
         tensor<TRef> wei(wei_len, wei_strides);
         tensor<Tout> out(out_len, out_strides);
-#if MIOPEN_BACKEND_OPENCL
-        cl_context ctx;
-        clGetCommandQueueInfo(q, CL_QUEUE_CONTEXT, sizeof(cl_context), &ctx, nullptr);
-        cl_int status = CL_SUCCESS;
-        cl_mem in_dev =
-            clCreateBuffer(ctx, CL_MEM_READ_WRITE, sizeof(TRef) * in_sz, nullptr, &status);
-        cl_mem wei_dev =
-            clCreateBuffer(ctx, CL_MEM_READ_WRITE, sizeof(TRef) * wei_sz, nullptr, nullptr);
-        cl_mem out_dev =
-            clCreateBuffer(ctx, CL_MEM_READ_WRITE, sizeof(Tout) * out_sz, nullptr, nullptr);
-        EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+
+        auto in_sz  = in.data.size();
+        auto wei_sz = wei.data.size();
+        auto out_sz = out.data.size();
+
         void* in_dev;
         void* wei_dev;
         void* out_dev;
+
         EXPECT(hipMalloc(&in_dev, sizeof(TRef) * in_sz) == hipSuccess);
         EXPECT(hipMalloc(&wei_dev, sizeof(TRef) * wei_sz) == hipSuccess);
         EXPECT(hipMalloc(&out_dev, sizeof(Tout) * out_sz) == hipSuccess);
-#endif
+
         EXPECT(miopenCreateConvolutionDescriptor(&convDesc) == miopenStatusSuccess);
         EXPECT(miopenInitConvolutionNdDescriptor(convDesc,
                                                  3,
@@ -795,35 +715,21 @@
             // initialize data with integer
             rand_tensor_integer(in);
             rand_tensor_integer(wei);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueWriteBuffer(q,
-                                          in_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * in_sz,
-                                          in.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            status |= clEnqueueWriteBuffer(q,
-                                          wei_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * wei_sz,
-                                          wei.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+            /// \ref copy_non_packed_output_before_convolution
+            rand_tensor_integer(out);
+
             EXPECT(hipMemcpy(
                        in_dev, in.data.data(), sizeof(TRef) * in_sz, hipMemcpyHostToDevice) ==
                    hipSuccess);
+            /// \ref copy_non_packed_output_before_convolution
+            EXPECT(hipMemcpy(out_dev,
+                             out.data.data(),
+                             sizeof(Tout) * out_sz,
+                             hipMemcpyHostToDevice) == hipSuccess);
             EXPECT(hipMemcpy(wei_dev,
                              wei.data.data(),
                              sizeof(TRef) * wei_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
-#endif

             cpu_convolution_forward(miopen::deref(convDesc).GetSpatialDimension(),
                                     in,
@@ -849,23 +755,11 @@
                    miopenStatusSuccess);

             tensor<Tout> out_host(out_len, out_strides);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueReadBuffer(q,
-                                         out_dev,
-                                         CL_TRUE,
-                                         0,
-                                         sizeof(Tout) * out_sz,
-                                         out_host.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+
             EXPECT(hipMemcpy(out_host.data.data(),
                              out_dev,
                              sizeof(Tout) * out_sz,
                              hipMemcpyDeviceToHost) == hipSuccess);
-#endif

             // we expect exact match, since we use integers
             valid_result = verify_tensor(out_host, out);
         }
@@ -875,36 +769,22 @@
         {
             // initialize data with integer
             rand_tensor_integer(out);
             rand_tensor_integer(wei);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueWriteBuffer(q,
-                                          out_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * out_sz,
-                                          out.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            status |= clEnqueueWriteBuffer(q,
-                                          wei_dev,
-                                          CL_TRUE,
-                                          0,
-                                          sizeof(TRef) * wei_sz,
-                                          wei.data.data(),
-                                          0,
-                                          nullptr,
-                                          nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+            /// \ref copy_non_packed_output_before_convolution
+            rand_tensor_integer(in);
+
+            /// \ref copy_non_packed_output_before_convolution
+            EXPECT(hipMemcpy(
+                       in_dev, in.data.data(), sizeof(TRef) * in_sz, hipMemcpyHostToDevice) ==
+                   hipSuccess);
             EXPECT(hipMemcpy(out_dev,
                              out.data.data(),
-                             sizeof(TRef) * out_sz,
+                             sizeof(Tout) * out_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
             EXPECT(hipMemcpy(wei_dev,
                              wei.data.data(),
                              sizeof(TRef) * wei_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
-#endif
+
             cpu_convolution_backward_data(miopen::deref(convDesc).GetSpatialDimension(),
                                           in,
                                           wei,
@@ -929,23 +809,11 @@
                                           miopenStatusSuccess);

             tensor<TRef> in_host(in_len, in_strides);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueReadBuffer(q,
-                                         in_dev,
-                                         CL_TRUE,
-                                         0,
-                                         sizeof(TRef) * in_sz,
-                                         in_host.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+
             EXPECT(hipMemcpy(in_host.data.data(),
                              in_dev,
                              sizeof(TRef) * in_sz,
                              hipMemcpyDeviceToHost) == hipSuccess);
-#endif

             // we expect exact match, since we use integers
             valid_result = verify_tensor(in_host, in);
         }
@@ -954,35 +822,22 @@
         {
             rand_tensor_integer(in, 3, -2);
             rand_tensor_integer(out, 3, -2);
-#if MIOPEN_BACKEND_OPENCL
-            status |= clEnqueueWriteBuffer(q,
-                                           in_dev,
-                                           CL_TRUE,
-                                           0,
-                                           sizeof(TRef) * in_sz,
-                                           in.data.data(),
-                                           0,
-                                           nullptr,
-                                           nullptr);
-            status |= clEnqueueWriteBuffer(q,
-                                           out_dev,
-                                           CL_TRUE,
-                                           0,
-                                           sizeof(TRef) * out_sz,
-                                           out.data.data(),
-                                           0,
-                                           nullptr,
-                                           nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+            /// \ref copy_non_packed_output_before_convolution
+            rand_tensor_integer(wei);
+
             EXPECT(hipMemcpy(
                        in_dev, in.data.data(), sizeof(TRef) * in_sz, hipMemcpyHostToDevice) ==
                    hipSuccess);
+            /// \ref copy_non_packed_output_before_convolution
+            EXPECT(hipMemcpy(wei_dev,
+                             wei.data.data(),
+                             sizeof(TRef) * wei_sz,
+                             hipMemcpyHostToDevice) == hipSuccess);
             EXPECT(hipMemcpy(out_dev,
                              out.data.data(),
-                             sizeof(TRef) * out_sz,
+                             sizeof(Tout) * out_sz,
                              hipMemcpyHostToDevice) == hipSuccess);
-#endif
+
             cpu_convolution_backward_weight(miopen::deref(convDesc).GetSpatialDimension(),
                                             in,
                                             wei,
@@ -1007,23 +862,11 @@
                                             miopenStatusSuccess);

             tensor<TRef> wei_host(wei_len, wei_strides);
-#if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueReadBuffer(q,
-                                         wei_dev,
-                                         CL_TRUE,
-                                         0,
-                                         sizeof(TRef) * wei_sz,
-                                         wei_host.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
-            EXPECT(status == CL_SUCCESS);
-#elif MIOPEN_BACKEND_HIP
+
             EXPECT(hipMemcpy(wei_host.data.data(),
                              wei_dev,
                              sizeof(TRef) * wei_sz,
                              hipMemcpyDeviceToHost) == hipSuccess);
-#endif

             // we expect exact match, since we use integers
             valid_result = verify_tensor(wei_host, wei, 8.0); // max possible int
@@ -1049,15 +892,9 @@
         miopenDestroyTensorDescriptor(weiDesc);
         miopenDestroyTensorDescriptor(outDesc);

-#if MIOPEN_BACKEND_OPENCL
-        clReleaseMemObject(in_dev);
-        clReleaseMemObject(wei_dev);
-        clReleaseMemObject(out_dev);
-#elif MIOPEN_BACKEND_HIP
         hipFree(in_dev);
         hipFree(wei_dev);
         hipFree(out_dev);
-#endif
     };

     iterate_conv_3d(run_conv_3d);
diff --git a/test/gtest/conv_api_strided_tensors.cpp b/test/gtest/conv_api_strided_tensors.cpp
index 2a59dcd696..04d56ec908 100644
--- a/test/gtest/conv_api_strided_tensors.cpp
+++ b/test/gtest/conv_api_strided_tensors.cpp
@@ -139,7 +139,9 @@ class ConvStridedTensors : public ::testing::Test
     std::vector<float> h_output;
 };

-// This test should be replaced when strided tensors are fully implemented
+/// \todo re-enable this test after NCDHW grouped convolution lands (PR 2429)
+/// \todo add cpu reference convolution for verification --amberhassaan
+#if 0
 TEST_F(ConvStridedTensors, ConvStridedTensorsNotImplemented)
 {
     auto device = Device(handle);
@@ -178,9 +180,8 @@ TEST_F(ConvStridedTensors, ConvStridedTensorsNotImplemented)
     const float alpha = 1.f;
     const float beta  = 0.f;

-    // miopenConvolutionForward() must return error if the format is not supported
     ASSERT_TRUE(device.Synchronize());
-    ASSERT_NE(miopenConvolutionForward(handle,
+    ASSERT_EQ(miopenConvolutionForward(handle,
                                        &alpha,
                                        input_descr,
                                        d_input.Data(),
@@ -196,3 +197,4 @@
               miopenStatusSuccess);
     ASSERT_TRUE(device.Synchronize());
 }
+#endif
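For completeness, conv_out_size in the test harness above follows the standard convolution output-size formula. A standalone sketch (the harness's own definition is not shown in this diff, so treat the body as an assumption):

    #include <cassert>

    static int conv_out_size(int in_size, int pad, int dilation, int ksize, int stride)
    {
        return (in_size + 2 * pad - dilation * (ksize - 1) - 1) / stride + 1;
    }

    int main()
    {
        assert(conv_out_size(32, 1, 1, 3, 1) == 32); // 3x3 filter, pad 1, stride 1: size kept
        assert(conv_out_size(32, 1, 1, 3, 2) == 16); // stride 2 halves it, rounding down
        return 0;
    }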