Commit

Extend sum reduction kernel to axis1 reduction on 2d arrays and add 1d argmin reduction kernel (#46)

Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
3 people authored Nov 15, 2022
1 parent c07601b commit cc8c089
Showing 4 changed files with 565 additions and 135 deletions.
7 changes: 7 additions & 0 deletions sklearn_numba_dpex/common/_utils.py
@@ -0,0 +1,7 @@
import math


def check_power_of_2(e):
if e != 2 ** (math.log2(e)):
raise ValueError(f"Expected a power of 2, got {e}")
return e
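
A quick usage sketch (illustrative values, not from the diff): the helper returns its argument unchanged when it is a power of two, so the factories in kernels.py below can validate and forward `work_group_size` in a single expression.

from sklearn_numba_dpex.common._utils import check_power_of_2

work_group_size = check_power_of_2(256)  # 256 == 2**8, returned unchanged
# A non-power of two is expected to raise:
#     ValueError: Expected a power of 2, got 96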
263 changes: 240 additions & 23 deletions sklearn_numba_dpex/common/kernels.py
@@ -24,6 +24,8 @@
import dpctl.tensor as dpt
import numba_dpex as dpex

from sklearn_numba_dpex.common._utils import check_power_of_2


zero_idx = np.int64(0)

@@ -130,10 +132,12 @@ def half_l2_norm(


@lru_cache
def make_sum_reduction_2d_axis1_kernel(size0, size1, work_group_size, device, dtype):
"""Implement data_2d.sum(axis=1) or data_1d.sum().

numba_dpex does not provide tools such as `cuda.reduce`, so we implement a
reduction strategy from scratch. The strategy relies on the commutativity of the
operation used for the reduction, which allows reducing the input in any order.

The strategy consists in performing local reductions in each work group using
local memory, where each work item combines two values, thus halving the number
of values,
@@ -147,66 +151,163 @@ def make_sum_reduction_1d_kernel(size, work_group_size, device, dtype):
`2 * work_group_size`. This is repeated as many times as needed until only one
value remains in global memory.

Notes
-----
`work_group_size` is assumed to be a power of 2.

If `size1` is None then the kernel expects 1d tensor inputs. If `size1` is not
None then the expected shape of input tensors is `(size0, size1)`, and the
reduction operation is equivalent to input.sum(axis=1). In this case, the kernel
is a good choice if `size1` >> `preferred_work_group_size_multiple`, and if
`size0` is of the same order of magnitude as
`preferred_work_group_size_multiple`. If not, other reduction implementations
might give better performance.
"""
check_power_of_2(work_group_size)

# Number of iterations in each execution of the kernel:
local_n_iterations = np.int64(math.floor(math.log2(work_group_size)) - 1)

zero = dtype(0.0)
one_idx = np.int64(1)

minus_one_idx = np.int64(-1)
two_as_a_long = np.int64(2)

is_1d = size1 is None
# TODO: this set of kernel functions could be abstracted away to other coalescing
# functions
if is_1d:
sum_axis_size = size0
n_rows = np.int64(1)
local_values_size = work_group_size

@dpex.func
def set_col_to_zero(array, i):
array[i] = zero

@dpex.func
def copy_col(from_array, from_col, to_array, to_col):
to_array[to_col] = from_array[from_col]

@dpex.func
def add_cols(
from_array,
left_from_col,
right_from_col,
to_array,
to_col,
):
to_array[to_col] = from_array[left_from_col] + from_array[right_from_col]

@dpex.func
def add_cols_inplace(
array,
from_col,
to_col,
):
array[to_col] += array[from_col]

@dpex.func
def add_first_cols(from_array, to_array, to_col):
to_array[to_col] = from_array[zero_idx] + from_array[one_idx]

else:
sum_axis_size = size1
n_rows = size0
local_values_size = (size0, work_group_size)

@dpex.func
def set_col_to_zero(array, i):
for row in range(n_rows):
array[row, i] = zero

@dpex.func
def copy_col(from_array, from_col, to_array, to_col):
for row in range(n_rows):
to_array[row, to_col] = from_array[row, from_col]

@dpex.func
def add_cols(
from_array,
left_from_col,
right_from_col,
to_array,
to_col,
):
for row in range(n_rows):
to_array[row, to_col] = (
from_array[row, left_from_col] + from_array[row, right_from_col]
)

@dpex.func
def add_cols_inplace(
array,
from_col,
to_col,
):
for row in range(n_rows):
array[row, to_col] += array[row, from_col]

@dpex.func
def add_first_cols(from_array, to_array, to_col):
for row in range(n_rows):
to_array[row, to_col] = (
from_array[row, zero_idx] + from_array[row, one_idx]
)

# Optimized for C-contiguous arrays where the size of the sum axis is
# >> preferred_work_group_size_multiple, and the size of the other axis (if
# any) is smaller than or similar to preferred_work_group_size_multiple.
# ???: how does this strategy compare to having each thread reduce N contiguous
# items?
@dpex.kernel
# fmt: off
def partial_sum_reduction(
summands, # IN (n_rows, sum_axis_size)
result, # OUT (n_rows, math.ceil(sum_axis_size / (2 * work_group_size)))
):
# fmt: on
# NB: This kernel only performs a partial reduction
group_id = dpex.get_group_id(zero_idx)
local_work_id = dpex.get_local_id(zero_idx)
first_work_id = local_work_id == zero_idx

size = summands.shape[minus_one_idx]

local_values = dpex.local.array(local_values_size, dtype=dtype)

first_value_idx = group_id * work_group_size * two_as_a_long
augend_idx = first_value_idx + local_work_id
addend_idx = first_value_idx + work_group_size + local_work_id
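# Indexing sketch (illustrative numbers, added for clarity): with
# work_group_size = 4, the work group with group_id = 1 covers summands[8:16];
# its work item with local_work_id = 2 reads summands[10] (augend) and
# summands[14] (addend).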

# Each work item reads two values from global memory and sums them into local
# memory
if augend_idx >= size:
set_col_to_zero(local_values, local_work_id)
elif addend_idx >= size:
copy_col(summands, augend_idx, local_values, local_work_id)
else:
add_cols(summands, augend_idx, addend_idx, local_values, local_work_id)

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
current_n_work_items = work_group_size
for i in range(local_n_iterations):
# We discard half of the remaining active work items at each iteration
current_n_work_items = current_n_work_items // two_as_a_long
if local_work_id < current_n_work_items:
add_cols_inplace(local_values, local_work_id + current_n_work_items, local_work_id)

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

# At this point local_values[0] + local_values[1] is equal to the sum of all
# elements in summands that have been covered by the work group; we write it
# into global memory
if first_work_id:
add_first_cols(local_values, result, group_id)

# As many partial reductions as necessary are chained until only one element
# remains.
kernels_and_empty_tensors_pairs = []
n_groups = sum_axis_size
# TODO: at some point, the cost of scheduling the kernel is more than the cost
# of running the reduction iteration. At this point the loop should stop and
# then a single work item should iterate once over the remaining values to
# finish the
@@ -215,13 +316,129 @@ def partial_sum_reduction(
n_groups = math.ceil(n_groups / (2 * work_group_size))
global_size = n_groups * work_group_size
kernel = partial_sum_reduction[global_size, work_group_size]
result_shape = n_groups if is_1d else (n_rows, n_groups)
# NB: here memory for partial results is allocated ahead of time and will only
# be garbage collected when the instance of `sum_reduction` is garbage
# collected. Thus it can be more efficient to re-use the same instance of
# `sum_reduction` (e.g. within iterations of a loop), since it avoids
# deallocation and reallocation every time.
result = dpt.empty(result_shape, dtype=dtype, device=device)
kernels_and_empty_tensors_pairs.append((kernel, result))

def sum_reduction(summands):
# TODO: manually dispatch the kernels with a SyclQueue
for kernel, result in kernels_and_empty_tensors_pairs:
kernel(summands, result)
summands = result
return result

return sum_reduction
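
To make the factory's contract concrete, here is a minimal usage sketch (the device selection, sizes, and work group size are illustrative assumptions, not part of this diff):

import dpctl
import dpctl.tensor as dpt
import numpy as np

device = dpctl.SyclDevice()  # default SYCL device
dtype = np.float32

# 2d case: each row is summed along axis 1 through chained partial reductions.
data = dpt.reshape(dpt.arange(20 * 3000, dtype=dtype, device=device), (20, 3000))
sum_rows = make_sum_reduction_2d_axis1_kernel(
    size0=20, size1=3000, work_group_size=128, device=device, dtype=dtype
)
row_sums = sum_rows(data)  # dpt tensor of shape (20, 1)

# 1d case: pass size1=None to sum a flat tensor down to a single value.
vector = dpt.arange(3000, dtype=dtype, device=device)
sum_all = make_sum_reduction_2d_axis1_kernel(
    size0=3000, size1=None, work_group_size=128, device=device, dtype=dtype
)
total = sum_all(vector)  # dpt tensor of shape (1,)
print(dpt.asnumpy(total)[0])  # 4498500.0 == sum(range(3000))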


@lru_cache
def make_argmin_reduction_1d_kernel(size, work_group_size, device, dtype):
"""Implement 1d argmin with the same strategy than for make_sum_reduction_2d_axis1_kernel."""
check_power_of_2(work_group_size)

# Number of iterations in each execution of the kernel:
local_n_iterations = np.int64(math.floor(math.log2(work_group_size)) - 1)

two_as_a_long = np.int64(2)
one_idx = np.int64(1)
inf = dtype(np.inf)

# TODO: the first call of partial_argmin_reduction in the final loop should be
# written with only two arguments, since "previous_result" does not exist yet.
# It seems it's not possible to factor the code well enough with @dpex.kernel
# to avoid duplicating most of it, so for now we resort to branching.
@dpex.kernel
# fmt: off
def partial_argmin_reduction(
values, # IN (size,)
previous_result, # IN (current_size,)
argmin_indices, # OUT (math.ceil(
#     (current_size if current_size else size)
#     / (2 * work_group_size)
# ),)
):
# fmt: on
group_id = dpex.get_group_id(zero_idx)
local_work_id = dpex.get_local_id(zero_idx)
first_work_id = local_work_id == zero_idx

previous_result_size = previous_result.shape[zero_idx]
has_previous_result = previous_result_size > one_idx
current_size = previous_result_size if has_previous_result else values.shape[zero_idx]
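# Chaining sketch: on the first call in the chain, previous_result is a dummy
# tensor of size 1, so the kernel scans `values` directly; on later calls
# previous_result holds the candidate indices into `values` selected by the
# previous pass, and x_idx / y_idx are first translated through it.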

local_argmin = dpex.local.array(work_group_size, dtype=np.int32)
local_values = dpex.local.array(work_group_size, dtype=dtype)

first_value_idx = group_id * work_group_size * two_as_a_long
x_idx = first_value_idx + local_work_id
y_idx = first_value_idx + work_group_size + local_work_id

if x_idx >= current_size:
local_values[local_work_id] = inf
else:
if has_previous_result:
x_idx = previous_result[x_idx]

if y_idx >= current_size:
local_argmin[local_work_id] = x_idx
local_values[local_work_id] = values[x_idx]

else:
if has_previous_result:
y_idx = previous_result[y_idx]

x = values[x_idx]
y = values[y_idx]
if x < y or (x == y and x_idx < y_idx):
local_argmin[local_work_id] = x_idx
local_values[local_work_id] = x
else:
local_argmin[local_work_id] = y_idx
local_values[local_work_id] = y

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
current_n_work_items = work_group_size
for i in range(local_n_iterations):
current_n_work_items = current_n_work_items // two_as_a_long
if local_work_id < current_n_work_items:
local_x_idx = local_work_id
local_y_idx = local_work_id + current_n_work_items

x = local_values[local_x_idx]
y = local_values[local_y_idx]

if x > y:
local_values[local_x_idx] = y
local_argmin[local_x_idx] = local_argmin[local_y_idx]

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

if first_work_id:
if local_values[zero_idx] <= local_values[one_idx]:
argmin_indices[group_id] = local_argmin[zero_idx]
else:
argmin_indices[group_id] = local_argmin[one_idx]

# As many partial reductions as necessary are chained until only one element
# remains.
kernels_and_empty_tensors_tuples = []
n_groups = size
previous_result = dpt.empty((1,), dtype=np.int32, device=device)
while n_groups > 1:
n_groups = math.ceil(n_groups / (2 * work_group_size))
global_size = n_groups * work_group_size
kernel = partial_argmin_reduction[global_size, work_group_size]
result = dpt.empty(n_groups, dtype=np.int32, device=device)
kernels_and_empty_tensors_tuples.append((kernel, previous_result, result))
previous_result = result

def argmin_reduction(values):
for kernel, previous_result, result in kernels_and_empty_tensors_tuples:
kernel(values, previous_result, result)
return result

return argmin_reduction
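
As above, a short usage sketch (illustrative values, reusing the `device` and `dtype` assumptions from the previous sketch):

argmin = make_argmin_reduction_1d_kernel(
    size=3000, work_group_size=128, device=device, dtype=dtype
)
values = dpt.arange(3000, 0, -1, dtype=dtype, device=device)  # 3000.0, ..., 1.0
idx = argmin(values)  # dpt int32 tensor of shape (1,)
print(dpt.asnumpy(idx)[0])  # 2999: the minimum, 1.0, sits at the last index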