diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 631d404845f7..4040d82b33e7 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -64,8 +64,13 @@ class OpenCLWrappedFunc { } // setup arguments. for (cl_uint i = 0; i < arg_size_.size(); ++i) { - auto* arg = static_cast(void_args[i]); - OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg->buffer)); + void* arg = nullptr; + if (args.type_codes[i] == DLDataTypeCode::kDLOpaqueHandle) { + arg = static_cast(void_args[i])->buffer; + } else { + arg = void_args[i]; + } + OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg)); } cl_command_queue queue = w_->GetQueue(t->device); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);