Skip to content

Commit

Permalink
Many improvements to device support
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Oct 23, 2023
1 parent 967e0eb commit b687835
Show file tree
Hide file tree
Showing 20 changed files with 453 additions and 119 deletions.
18 changes: 16 additions & 2 deletions device/include/roughpy/device/device_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@
namespace rpy {
namespace devices {

struct ExtensionSourceAndOptions {
std::vector<string> sources;
string compile_options;
std::vector<pair<string, string>> header_name_and_source;
string link_options;
};



/**
* @brief Interface for interacting with compute devices.
*
Expand All @@ -69,6 +78,7 @@ class RPY_EXPORT DeviceHandle

virtual ~DeviceHandle();

RPY_NO_DISCARD virtual DeviceType type() const noexcept;
RPY_NO_DISCARD virtual DeviceCategory category() const noexcept;

RPY_NO_DISCARD virtual DeviceInfo info() const noexcept;
Expand All @@ -85,14 +95,18 @@ class RPY_EXPORT DeviceHandle

virtual void raw_free(void* pointer, dimn_t size) const;

virtual bool has_compiler() const noexcept;

virtual const Kernel& register_kernel(Kernel kernel) const;

RPY_NO_DISCARD
virtual optional<Kernel> get_kernel(const string& name) const noexcept;
RPY_NO_DISCARD
virtual optional<Kernel> compile_kernel_from_str(string_view code) const;
virtual optional<Kernel>
compile_kernel_from_str(const ExtensionSourceAndOptions& args) const;

virtual void compile_kernels_from_src(string_view code) const;
virtual void compile_kernels_from_src(const ExtensionSourceAndOptions& args
) const;

RPY_NO_DISCARD
virtual Event new_event() const;
Expand Down
10 changes: 9 additions & 1 deletion device/include/roughpy/device/device_object_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class RPY_EXPORT InterfaceBase
public:
virtual ~InterfaceBase();

RPY_NO_DISCARD virtual DeviceType type() const noexcept;
RPY_NO_DISCARD virtual dimn_t ref_count() const noexcept;
RPY_NO_DISCARD virtual std::unique_ptr<InterfaceBase> clone() const;
RPY_NO_DISCARD virtual Device device() const noexcept;
Expand Down Expand Up @@ -98,7 +99,7 @@ class ObjectBase
ObjectBase& operator=(const ObjectBase& other);
ObjectBase& operator=(ObjectBase&& other) noexcept = default;


RPY_NO_DISCARD DeviceType type() const noexcept;
RPY_NO_DISCARD bool is_null() const noexcept { return !p_impl; }
RPY_NO_DISCARD dimn_t ref_count() const noexcept;
RPY_NO_DISCARD Derived clone() const;
Expand Down Expand Up @@ -128,6 +129,13 @@ ObjectBase<Interface, Derived>::operator=(const ObjectBase& other)
return *this;
}

template <typename Interface, typename Derived>
DeviceType ObjectBase<Interface, Derived>::type() const noexcept
{
if (p_impl) { return p_impl->type(); }
return DeviceType::CPU;
}

template <typename Interface, typename Derived>
dimn_t ObjectBase<Interface, Derived>::ref_count() const noexcept
{
Expand Down
16 changes: 12 additions & 4 deletions device/src/cpu/cpu_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,21 @@ optional<Kernel> CPUDeviceHandle::get_kernel(const string& name) const noexcept

return {};
}
optional<Kernel> CPUDeviceHandle::compile_kernel_from_str(string_view code
optional<Kernel>
CPUDeviceHandle::compile_kernel_from_str(const ExtensionSourceAndOptions& args
) const
{
if (p_ocl_handle) {
return p_ocl_handle->compile_kernel_from_str(code);
return p_ocl_handle->compile_kernel_from_str(args);
}
return {};
}
void CPUDeviceHandle::compile_kernels_from_src(string_view code) const
void CPUDeviceHandle::compile_kernels_from_src(
const ExtensionSourceAndOptions& args
) const
{
if (p_ocl_handle) {
p_ocl_handle->compile_kernels_from_src(code);
p_ocl_handle->compile_kernels_from_src(args);
}
}
Event CPUDeviceHandle::new_event() const {
Expand All @@ -262,3 +265,8 @@ DeviceCategory CPUDeviceHandle::category() const noexcept
{
return DeviceCategory::CPU;
}
bool CPUDeviceHandle::has_compiler() const noexcept
{
if (p_ocl_handle) { return p_ocl_handle->has_compiler(); }
return false;
}
9 changes: 7 additions & 2 deletions device/src/cpu/cpu_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,14 @@ class CPUDeviceHandle : public DeviceHandle
DeviceInfo info() const noexcept override;
Buffer raw_alloc(dimn_t count, dimn_t alignment) const override;
void raw_free(void* pointer, dimn_t size) const override;

bool has_compiler() const noexcept override;
optional<Kernel> get_kernel(const string& name) const noexcept override;
optional<Kernel> compile_kernel_from_str(string_view code) const override;
void compile_kernels_from_src(string_view code) const override;
optional<Kernel>
compile_kernel_from_str(const ExtensionSourceAndOptions& args
) const override;
void compile_kernels_from_src(const ExtensionSourceAndOptions& args
) const override;
Event new_event() const override;
Queue new_queue() const override;
Queue get_default_queue() const override;
Expand Down
10 changes: 8 additions & 2 deletions device/src/device_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ Buffer DeviceHandle::raw_alloc(rpy::dimn_t count, rpy::dimn_t alignment) const

void DeviceHandle::raw_free(void* pointer, dimn_t size) const {}

bool DeviceHandle::has_compiler() const noexcept { return false; }

const Kernel& DeviceHandle::register_kernel(Kernel kernel) const {
RPY_CHECK(kernel.device() == this);
const guard_type access(get_lock());
Expand All @@ -85,11 +87,14 @@ optional<Kernel> DeviceHandle::get_kernel(const string& name) const noexcept
return {};
}

optional<Kernel> DeviceHandle::compile_kernel_from_str(string_view code) const
optional<Kernel>
DeviceHandle::compile_kernel_from_str(const ExtensionSourceAndOptions& args
) const
{
return {};
}
void DeviceHandle::compile_kernels_from_src(string_view RPY_UNUSED_VAR code
void DeviceHandle::compile_kernels_from_src(
const ExtensionSourceAndOptions& args
) const
{}
Event DeviceHandle::new_event() const { return {}; }
Expand All @@ -108,3 +113,4 @@ bool DeviceHandle::supports_type(const TypeInfo& info) const noexcept
{
return false;
}
DeviceType DeviceHandle::type() const noexcept { return DeviceType::CPU; }
4 changes: 4 additions & 0 deletions device/src/device_interface_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ rpy::devices::dtl::InterfaceBase::~InterfaceBase() = default;

dimn_t devices::dtl::InterfaceBase::ref_count() const noexcept { return 1; }

DeviceType devices::dtl::InterfaceBase::type() const noexcept {
return DeviceType::CPU;
}

std::unique_ptr<rpy::devices::dtl::InterfaceBase>
rpy::devices::dtl::InterfaceBase::clone() const
{
Expand Down
3 changes: 3 additions & 0 deletions device/src/opencl/ocl_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,6 @@ dimn_t OCLBuffer::ref_count() const noexcept
}
return 0;
}
DeviceType OCLBuffer::type() const noexcept {
return DeviceType::OpenCL;
}
2 changes: 2 additions & 0 deletions device/src/opencl/ocl_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class OCLBuffer : public BufferInterface
OCLBuffer(cl_mem buffer, OCLDevice dev) noexcept;
~OCLBuffer() override;

RPY_NO_DISCARD
DeviceType type() const noexcept override;
RPY_NO_DISCARD
dimn_t ref_count() const noexcept override;
RPY_NO_DISCARD
Expand Down
6 changes: 2 additions & 4 deletions device/src/opencl/ocl_decls.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ class OCLQueue;

using OCLDevice = boost::intrusive_ptr<const OCLDeviceHandle>;


}
}

}// namespace devices
}// namespace rpy

#endif// ROUGHPY_DEVICE_SRC_OPENCL_OCL_DECLS_H_
Loading

0 comments on commit b687835

Please sign in to comment.