Skip to content

Commit

Permalink
Added basic launch functions for kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Oct 11, 2023
1 parent 1021ca0 commit eca7035
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 41 deletions.
78 changes: 37 additions & 41 deletions device/include/roughpy/device/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,73 +29,69 @@
#define ROUGHPY_DEVICE_KERNEL_H_

#include "core.h"
#include "event.h"
#include "device_object_base.h"
#include "event.h"

#include <roughpy/core/macros.h>
#include <roughpy/core/types.h>
#include <roughpy/core/slice.h>
#include <roughpy/core/types.h>

namespace rpy {
namespace device {



namespace rpy { namespace device {


class KernelLaunchParams {
class KernelLaunchParams
{
Dim3 m_work_size;
Dim3 m_group_size;
optional<Dim3> m_offsets;

public:

KernelLaunchParams();

};

class RPY_EXPORT KernelInterface : public dtl::InterfaceBase {

class RPY_EXPORT KernelInterface : public dtl::InterfaceBase
{

public:
RPY_NO_DISCARD virtual string_view name(void* content) const;

RPY_NO_DISCARD
virtual string_view name(void* content) const;

RPY_NO_DISCARD
virtual dimn_t num_args(void* content) const;

RPY_NO_DISCARD
virtual Event launch_kernel_async(void* content,
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params) const;



virtual EventStatus launch_kernel_sync(void* content,
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params) const;

RPY_NO_DISCARD virtual dimn_t num_args(void* content) const;

RPY_NO_DISCARD virtual Event launch_kernel_async(
void* content,
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params
) const;

};


class RPY_EXPORT Kernel : public dtl::ObjectBase<KernelInterface, Kernel> {
class RPY_EXPORT Kernel : public dtl::ObjectBase<KernelInterface, Kernel>
{

public:
RPY_NO_DISCARD string_view name() const;

RPY_NO_DISCARD dimn_t num_args() const;

RPY_NO_DISCARD
string_view name() const;
RPY_NO_DISCARD Event launch_async(
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params
);

RPY_NO_DISCARD
dimn_t num_args() const;
RPY_NO_DISCARD EventStatus launch_sync(
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params
);

};

}}
}// namespace device
}// namespace rpy

#endif // ROUGHPY_DEVICE_KERNEL_H_
#endif// ROUGHPY_DEVICE_KERNEL_H_
31 changes: 31 additions & 0 deletions device/src/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,34 @@ dimn_t Kernel::num_args() const
}
return interface()->num_args(content());
}

Event Kernel::launch_async(
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params
)
{
if (interface() == nullptr || content() == nullptr) {
return Event(nullptr, nullptr);
}

auto nargs = interface()->num_args(content());
if (nargs != args.size() || nargs != arg_sizes.size()) {
RPY_THROW(std::runtime_error, "incorrect number of arguments provided");
}

return interface()->launch_kernel_async(content(), queue, args,
arg_sizes, params);
}
EventStatus Kernel::launch_sync(
Queue& queue,
Slice<void*> args,
Slice<dimn_t> arg_sizes,
const KernelLaunchParams& params
)
{
auto event = launch_async(queue, args, arg_sizes, params);
event.wait();
return event.status();
}

0 comments on commit eca7035

Please sign in to comment.