Skip to content

Commit

Permalink
feat: add more cuda kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ManasviGoyal committed Jan 10, 2024
1 parent d466d70 commit 35d13e6
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ ERROR awkward_ListOffsetArray_drop_none_indexes_64(
length_offsets,
length_indexes);
}

ERROR awkward_ListOffsetArray_drop_none_indexes_32(
int32_t* tooffsets,
const int32_t* noneindexes,
Expand Down
2 changes: 2 additions & 0 deletions dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_NumpyArray_reduce_adjust_starts_64",
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
"awkward_RegularArray_getitem_next_at",
Expand Down
2 changes: 2 additions & 0 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,9 @@ def gencpukerneltests(specdict):
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_NumpyArray_reduce_adjust_starts_64",
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
"awkward_RegularArray_getitem_next_at",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

enum class LISTARRAY_GETITEM_NEXT_ARRAY_ADVANCED_ERRORS {
STOP_LT_START, // message: "stops[i] < starts[i]"
STOP_GET_LEN, // message: "stops[i] > len(content)"
IND_OUT_OF_RANGE, // message: "index out of range"
};

template <typename T, typename C, typename U, typename V, typename W, typename X>
__global__ void
awkward_ListArray_getitem_next_array_advanced(T* tocarry,
C* toadvanced,
const U* fromstarts,
const V* fromstops,
const W* fromarray,
const X* fromadvanced,
int64_t lenstarts,
int64_t lenarray,
int64_t lencontent,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id < lenstarts) {
if (fromstops[thread_id] < fromstarts[thread_id]) {
RAISE_ERROR(LISTARRAY_GETITEM_NEXT_ARRAY_ADVANCED_ERRORS::STOP_LT_START)
}
if ((fromstarts[thread_id] != fromstops[thread_id]) &&
(fromstops[thread_id] > lencontent)) {
RAISE_ERROR(LISTARRAY_GETITEM_NEXT_ARRAY_ADVANCED_ERRORS::STOP_GET_LEN)
}
int64_t length = fromstops[thread_id] - fromstarts[thread_id];
int64_t regular_at = fromarray[fromadvanced[thread_id]];
if (regular_at < 0) {
regular_at += length;
}
if (!(0 <= regular_at && regular_at < length)) {
RAISE_ERROR(LISTARRAY_GETITEM_NEXT_ARRAY_ADVANCED_ERRORS::IND_OUT_OF_RANGE)
}
tocarry[thread_id] = fromstarts[thread_id] + regular_at;
toadvanced[thread_id] = thread_id;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

enum class LISTARRAY_GETITEM_NEXT_AT_ERRORS {
IND_OUT_OF_RANGE, // message: "index out of range"
};

template <typename T, typename C, typename U>
__global__ void
awkward_ListArray_getitem_next_at(T* tocarry,
const C* fromstarts,
const U* fromstops,
int64_t lenstarts,
int64_t at,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id < lenstarts) {
int64_t length = fromstops[thread_id] - fromstarts[thread_id];
int64_t regular_at = at;
if (regular_at < 0) {
regular_at += length;
}
if (!(0 <= regular_at && regular_at < length)) {
RAISE_ERROR(LISTARRAY_GETITEM_NEXT_AT_ERRORS::IND_OUT_OF_RANGE)
}
tocarry[thread_id] = fromstarts[thread_id] + regular_at;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ awkward_ListArray_validity(const C* starts,
int64_t lencontent,
uint64_t invocation_index,
uint64_t* err_code) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id < length) {
C start = starts[thread_id];
T stop = stops[thread_id];
Expand Down

0 comments on commit 35d13e6

Please sign in to comment.