Skip to content

Commit

Permalink
Change the model to something more sane.
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Oct 12, 2023
1 parent 23c7059 commit 7f7f6a0
Show file tree
Hide file tree
Showing 27 changed files with 517 additions and 327 deletions.
6 changes: 3 additions & 3 deletions device/include/roughpy/device/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ enum class BufferMode
class BufferInterface : public dtl::InterfaceBase
{
public:
RPY_NO_DISCARD virtual BufferMode mode(void* content) const;
RPY_NO_DISCARD virtual BufferMode mode() const;

RPY_NO_DISCARD virtual dimn_t size(void* content) const;
RPY_NO_DISCARD virtual dimn_t size() const;

RPY_NO_DISCARD virtual void* ptr(void* content) const;
RPY_NO_DISCARD virtual void* ptr();
};

class Buffer : public dtl::ObjectBase<BufferInterface, Buffer>
Expand Down
23 changes: 11 additions & 12 deletions device/include/roughpy/device/device_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,6 @@ class RPY_EXPORT DeviceHandle
{

public:
RPY_NO_DISCARD virtual const BufferInterface*
buffer_interface() const noexcept = 0;

RPY_NO_DISCARD virtual const EventInterface*
event_interface() const noexcept = 0;

RPY_NO_DISCARD virtual const KernelInterface*
kernel_interface() const noexcept = 0;

RPY_NO_DISCARD virtual const QueueInterface*
queue_interface() const noexcept = 0;


DeviceHandle();

Expand All @@ -79,6 +67,17 @@ class RPY_EXPORT DeviceHandle
raw_alloc(dimn_t count, dimn_t alignment) const;

virtual void raw_free(Buffer buffer) const;

virtual optional<Kernel> get_kernel(string_view name) const noexcept;
virtual optional<Kernel> compile_kernel_from_str(string_view code) const;

virtual void compile_kernels_from_src(string_view code) const;


virtual Event new_event() const;
virtual Queue new_queue() const;
virtual Queue get_default_queue() const;

};


Expand Down
65 changes: 43 additions & 22 deletions device/include/roughpy/device/device_object_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@

#ifndef ROUGHPY_DEVICE_DEVICE_OBJECT_BASE_H_
#define ROUGHPY_DEVICE_DEVICE_OBJECT_BASE_H_

#include "core.h"

#include <roughpy/core/macros.h>
#include <roughpy/core/traits.h>
#include <roughpy/core/types.h>

#include <memory>

namespace rpy {
namespace device {
namespace dtl {
Expand All @@ -42,9 +45,8 @@ class RPY_EXPORT InterfaceBase
public:
virtual ~InterfaceBase();

RPY_NO_DISCARD virtual void* clone(void* content) const;

virtual void clear(void* content) const;
RPY_NO_DISCARD virtual std::unique_ptr<InterfaceBase> clone() const;
RPY_NO_DISCARD virtual Device device() const noexcept;
};

template <typename Interface, typename Derived>
Expand All @@ -55,40 +57,59 @@ class ObjectBase
"Interface must be derived from InterfaceBase"
);

friend class rpy::device::DeviceHandle;
friend class InterfaceBase;

using interface_type = Interface;

const interface_type* p_interface;
void* p_content;
static std::unique_ptr<Interface>
downcast(std::unique_ptr<InterfaceBase>&& base) noexcept
{
return {reinterpret_cast<Interface*>(base.release())};
}

public:
ObjectBase() : p_content(nullptr), p_interface(nullptr) {}

explicit ObjectBase(
const interface_type* interface,
void* content = nullptr
)
: p_interface(interface),
p_content(content)
protected:
std::unique_ptr<Interface> p_impl;

private:
explicit ObjectBase(std::unique_ptr<InterfaceBase>&& base)
: p_impl(downcast(std::move(base)))
{}

RPY_NO_DISCARD void* content() const noexcept { return p_content; }
RPY_NO_DISCARD const interface_type* interface() const noexcept
{
return p_interface;
}
public:
ObjectBase() = default;

template <
typename IFace,
typename = enable_if_t<is_base_of<Interface, IFace>::value>>
explicit ObjectBase(std::unique_ptr<IFace>&& base)
: p_impl(std::move(base))
{}

explicit ObjectBase(std::unique_ptr<Interface>&& base)
: p_impl(std::move(base))
{}

RPY_NO_DISCARD Derived clone() const;
RPY_NO_DISCARD Device device() const noexcept;
};

template <typename Interface, typename Derived>
Derived ObjectBase<Interface, Derived>::clone() const
{
RPY_CHECK(p_interface != nullptr);
return Derived(p_interface, p_interface->clone(p_content));
RPY_CHECK(p_impl);
return Derived(p_impl->clone());
}

template <typename Interface, typename Derived>
Device device::dtl::ObjectBase<Interface, Derived>::device() const noexcept
{
if (p_impl) { return p_impl->device(); }
return nullptr;
}

}// namespace dtl
}// namespace device
}// namespace rpy

#endif // ROUGHPY_DEVICE_DEVICE_OBJECT_BASE_H_
#endif// ROUGHPY_DEVICE_DEVICE_OBJECT_BASE_H_
4 changes: 2 additions & 2 deletions device/include/roughpy/device/event.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ namespace device {
class RPY_EXPORT EventInterface : public dtl::InterfaceBase
{
public:
virtual void wait(void* content) const;
virtual void wait();

virtual EventStatus status(void* content) const;
virtual EventStatus status() const;
};

class Event : public dtl::ObjectBase<EventInterface, Event>
Expand Down
14 changes: 7 additions & 7 deletions device/include/roughpy/device/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,25 @@ class RPY_EXPORT KernelInterface : public dtl::InterfaceBase
{

public:
RPY_NO_DISCARD virtual string_view name(void* content) const;
RPY_NO_DISCARD virtual string_view name() const;

RPY_NO_DISCARD virtual dimn_t num_args(void* content) const;
RPY_NO_DISCARD virtual dimn_t num_args() const;

RPY_NO_DISCARD virtual Event launch_kernel_async(
void* content,
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params
) const;

);
};

class RPY_EXPORT Kernel : public dtl::ObjectBase<KernelInterface, Kernel>
{
using base_t = dtl::ObjectBase<KernelInterface, Kernel>;

public:
using base_t::base_t;

RPY_NO_DISCARD string_view name() const;

RPY_NO_DISCARD dimn_t num_args() const;
Expand All @@ -87,8 +88,7 @@ class RPY_EXPORT Kernel : public dtl::ObjectBase<KernelInterface, Kernel>
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params
);

);
};

}// namespace device
Expand Down
2 changes: 1 addition & 1 deletion device/include/roughpy/device/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class QueueInterface : public dtl::InterfaceBase
{
public:

virtual dimn_t size(void* content) const;
virtual dimn_t size() const;

};

Expand Down
12 changes: 6 additions & 6 deletions device/src/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,22 @@ using namespace rpy::device;


BufferMode Buffer::mode() const {
if (interface() == nullptr || content() == nullptr) {
if (!p_impl){
return BufferMode::Read;
}
return interface()->mode(content());
return p_impl->mode();
}

dimn_t Buffer::size() const {
if (interface() == nullptr || content() == nullptr) {
if (!p_impl) {
return 0;
}
return interface()->size(content());
return p_impl->size();
}

void* Buffer::ptr() {
if (interface() == nullptr || content() == nullptr) {
if (!p_impl) {
return nullptr;
}
return interface()->ptr(content());
return p_impl->ptr();
}
6 changes: 3 additions & 3 deletions device/src/buffer_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ using namespace rpy;
using namespace rpy::device;


BufferMode BufferInterface::mode(void* content) const {
BufferMode BufferInterface::mode() const {
return BufferMode::Read;
}

dimn_t BufferInterface::size(void* content) const {
dimn_t BufferInterface::size() const {
return 0;
}

void* BufferInterface::ptr(void* content) const {
void* BufferInterface::ptr() {
return nullptr;
}
12 changes: 9 additions & 3 deletions device/src/device_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@

#include <roughpy/device/buffer.h>


#include <mutex>
#include <vector>

using namespace rpy;
using namespace rpy::device;

Expand All @@ -45,12 +49,14 @@ optional<fs::path> DeviceHandle::runtime_library() const noexcept

DeviceHandle::~DeviceHandle() = default;


Buffer DeviceHandle::raw_alloc(rpy::dimn_t count, rpy::dimn_t alignment) const
{
return Buffer{nullptr, nullptr};
return {};
}

void DeviceHandle::raw_free(Buffer buffer) const {
void DeviceHandle::raw_free(Buffer buffer) const {}

optional<Kernel> DeviceHandle::get_kernel(string_view name) const noexcept
{
return {};
}
11 changes: 8 additions & 3 deletions device/src/device_interface_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,20 @@

#include <roughpy/device/device_object_base.h>

#include "device_handle.h"

using namespace rpy;
using namespace rpy::device;

rpy::device::dtl::InterfaceBase::~InterfaceBase() = default;

void* rpy::device::dtl::InterfaceBase::clone(void* content) const
std::unique_ptr<rpy::device::dtl::InterfaceBase>
rpy::device::dtl::InterfaceBase::clone() const
{
return nullptr;
}

void rpy::device::dtl::InterfaceBase::clear(void* RPY_UNUSED_VAR content) const
{}
Device device::dtl::InterfaceBase::device() const noexcept
{
return Device(nullptr);
}
8 changes: 4 additions & 4 deletions device/src/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ using namespace rpy;
using namespace rpy::device;

void Event::wait() {
if (interface() != nullptr && content() != nullptr) {
interface()->wait(content());
if (p_impl) {
p_impl->wait();
}
}


EventStatus Event::status() const {
if (interface() == nullptr || content() == nullptr) {
if (!p_impl) {
return EventStatus::CompletedSuccessfully;
}
return interface()->status(content());
return p_impl->status();
}
4 changes: 2 additions & 2 deletions device/src/event_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ using namespace rpy;
using namespace rpy::device;


void EventInterface::wait(void* RPY_UNUSED_VAR content) const
void EventInterface::wait()
{
}

EventStatus EventInterface::status(void* RPY_UNUSED_VAR content) const {
EventStatus EventInterface::status() const {
return EventStatus::CompletedSuccessfully;
}
Loading

0 comments on commit 7f7f6a0

Please sign in to comment.