Skip to content

Commit

Permalink
Extend sum reduction kernel to axis1 reduction on 2d arrays and add 1…
Browse files Browse the repository at this point in the history
…d argmin reduction kernel
  • Loading branch information
fcharras committed Oct 31, 2022
1 parent 3d1bf13 commit ee0897f
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 22 deletions.
233 changes: 218 additions & 15 deletions sklearn_numba_dpex/common/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def half_l2_norm(


@lru_cache
def make_sum_reduction_1d_kernel(size, work_group_size, device, dtype):
def make_sum_reduction_2d_axis1_kernel(size0, size1, work_group_size, device, dtype):
"""numba_dpex does not provide tools such as `cuda.reduce` so we implement from
scratch a reduction strategy. The strategy relies on the commutativity of the
operation used for the reduction, thus allowing to reduce the input in any order.
Expand All @@ -148,6 +148,16 @@ def make_sum_reduction_1d_kernel(size, work_group_size, device, dtype):
remains in global memory.
NB: work_group_size is assumed to be a power of 2.
NB: if size1 is None then the kernel expects 1d tensor inputs. If size1 is not None
then the expected shape of input tensors is (size0, size1), and the reduction
operation is equivalent to input.sum(axis=1). In this case, the kernel is a good
choice if size1 >> preferred_work_group_size_multiple, and if size0 is of the
same order of magnitude as preferred_work_group_size_multiple. If not, other
reduction implementations might give better performance.
???: how does this strategy compare to having each thread reduce N contiguous
items?
"""
# Number of iteration in each execution of the kernel:
local_n_iterations = np.int64(math.floor(math.log2(work_group_size)) - 1)
Expand All @@ -156,21 +166,107 @@ def make_sum_reduction_1d_kernel(size, work_group_size, device, dtype):
two_long = np.int64(2)
one_idx = np.int64(1)

is_1d = size1 is None
if is_1d:
sum_axis_size = size0
n_rows = np.int64(1)
local_data_size = work_group_size

@dpex.func
def set_col_to_zero(array, i):
array[i] = zero

@dpex.func
def copy_col(from_array, from_col, to_array, to_col):
to_array[to_col] = from_array[from_col]

@dpex.func
def coalesce_cols(
from_array,
left_from_col,
right_from_col,
to_array,
to_col,
):
to_array[to_col] = from_array[left_from_col] + from_array[right_from_col]

@dpex.func
def coalesce_cols_inplace(
array,
from_col,
to_col,
):
array[to_col] += array[from_col]

@dpex.func
def coalesce_first_cols(from_array, to_array, to_col):
to_array[to_col] = from_array[zero_idx] + from_array[one_idx]

else:
sum_axis_size = size1
n_rows = size0
local_data_size = (size0, work_group_size)

@dpex.func
def set_col_to_zero(array, i):
for row in range(n_rows):
array[row, i] = zero

@dpex.func
def copy_col(from_array, from_col, to_array, to_col):
for row in range(n_rows):
to_array[row, to_col] = from_array[row, from_col]

@dpex.func
def coalesce_cols(
from_array,
left_from_col,
right_from_col,
to_array,
to_col,
):
for row in range(n_rows):
to_array[row, to_col] = (
from_array[row, left_from_col] + from_array[row, right_from_col]
)

@dpex.func
def coalesce_cols_inplace(
array,
from_col,
to_col,
):
for row in range(n_rows):
array[row, to_col] += array[row, from_col]

@dpex.func
def coalesce_first_cols(from_array, to_array, to_col):
for row in range(n_rows):
to_array[row, to_col] = (
from_array[row, zero_idx] + from_array[row, one_idx]
)

two_long = np.int64(2)
m_one_idx = np.int64(-1)

# Optimized for C-contiguous arrays where the size of the sum axis is
# >> preferred_work_group_size_multiple, and the size of the other axis (if any)
# is smaller than or similar to preferred_work_group_size_multiple.
@dpex.kernel
# fmt: off
def partial_sum_reduction(
summands, # IN (size,)
result, # OUT (math.ceil(size / (2 * work_group_size),)
summands, # IN (n_rows, sum_axis_size)
result, # OUT (n_rows, math.ceil(size / (2 * work_group_size),)
):
# fmt: on
# NB: This kernel only performs a partial reduction
group_id = dpex.get_group_id(zero_idx)
local_work_id = dpex.get_local_id(zero_idx)
first_work_id = local_work_id == zero_idx

size = summands.shape[zero_idx]
size = summands.shape[m_one_idx]

local_data = dpex.local.array(work_group_size, dtype=dtype)
local_data = dpex.local.array(local_data_size, dtype=dtype)

first_value_idx = group_id * work_group_size * two_long
augend_idx = first_value_idx + local_work_id
Expand All @@ -179,34 +275,32 @@ def partial_sum_reduction(
# Each work item reads two values from global memory and sums them into local
# memory
if augend_idx >= size:
local_data[local_work_id] = zero
set_col_to_zero(local_data, local_work_id)
elif addend_idx >= size:
local_data[local_work_id] = summands[augend_idx]
copy_col(summands, augend_idx, local_data, local_work_id)
else:
local_data[local_work_id] = summands[augend_idx] + summands[addend_idx]
coalesce_cols(summands, augend_idx, addend_idx, local_data, local_work_id)

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
current_n_work_items = work_group_size
for i in range(local_n_iterations):
# We discard half of the remaining active work items at each iteration
current_n_work_items = current_n_work_items // two_long
if local_work_id < current_n_work_items:
local_data[local_work_id] += local_data[
local_work_id + current_n_work_items
]
coalesce_cols_inplace(local_data, local_work_id + current_n_work_items, local_work_id)

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

# At this point local_data[0] = local_data[1] is equal to the sum of all
# At this point local_data[0] + local_data[1] is equal to the sum of all
# elements in summands that have been covered by the work group, we write it
# into global memory
if first_work_id:
result[group_id] = local_data[zero_idx] + local_data[one_idx]
coalesce_first_cols(local_data, result, group_id)

# As many partial reductions as necessary are chained until only one element
# remains.
kernels_and_empty_tensors_pairs = []
n_groups = size
n_groups = sum_axis_size
# TODO: at some point, the cost of scheduling the kernel is more than the cost of
# running the reduction iteration. At this point the loop should stop and then a
# single work item should iterate once over the remaining values to finish the
Expand All @@ -215,13 +309,122 @@ def partial_sum_reduction(
n_groups = math.ceil(n_groups / (2 * work_group_size))
global_size = n_groups * work_group_size
kernel = partial_sum_reduction[global_size, work_group_size]
result = dpt.empty(n_groups, dtype=dtype, device=device)
result_shape = n_groups if is_1d else (n_rows, n_groups)
result = dpt.empty(result_shape, dtype=dtype, device=device)
kernels_and_empty_tensors_pairs.append((kernel, result))

def sum_reduction(summands):
# TODO: manually dispatch the kernels with a SyclQueue
for kernel, result in kernels_and_empty_tensors_pairs:
kernel(summands, result)
summands = result
return result

return sum_reduction


@lru_cache
def make_argmin_reduction_1d_kernel(size, work_group_size, device, dtype):
    """Build a callable that computes the argmin of a 1d array of length `size`.

    The reduction uses the same chained partial-reduction strategy as
    `make_sum_reduction_2d_axis1_kernel`: each kernel launch makes every work
    group reduce `2 * work_group_size` candidates to a single index, and launches
    are chained until only one index remains. Intermediate and final index buffers
    are allocated once, when this factory is called.

    Parameters
    ----------
    size : int
        Number of elements of the 1d array that will be reduced. Must be >= 1.
    work_group_size : int
        Number of work items per work group. Must be a power of two and >= 2,
        because the kernel halves the number of active work items at each step and
        finally compares local slots 0 and 1.
    device
        Device specifier forwarded to the `dpt` allocation routines.
    dtype
        Scalar type of the data; it is called as `dtype(np.inf)` and used as the
        dtype of the local value buffer.

    Returns
    -------
    argmin_reduction : callable
        `argmin_reduction(data)` returns a one-element int32 array holding the
        index of a minimum of `data`. When several entries hold the minimum value,
        the returned index is not guaranteed to be the first one. The returned
        array is a preallocated buffer that is reused by every call (and, because
        of `lru_cache`, by every factory call with the same arguments).

    Raises
    ------
    ValueError
        If `size < 1` or if `work_group_size` is not a power of two >= 2.
    """
    if size < 1:
        raise ValueError(f"Expected size >= 1, got size={size}.")

    if (
        work_group_size < 2
        or 2 ** math.floor(math.log2(work_group_size)) != work_group_size
    ):
        raise ValueError(
            "Expected work_group_size to be a power of 2 greater than or equal to "
            f"2, got work_group_size={work_group_size}."
        )

    if size == 1:
        # A single element is trivially its own argmin. Without this shortcut no
        # kernel would be scheduled below and `argmin_reduction` would fail on an
        # unbound `result`.
        single_result = dpt.zeros(1, dtype=np.int32, device=device)

        def argmin_reduction(data):
            return single_result

        return argmin_reduction

    # Number of iteration in each execution of the kernel:
    local_n_iterations = np.int64(math.floor(math.log2(work_group_size)) - 1)

    two_long = np.int64(2)
    one_idx = np.int64(1)
    inf = dtype(np.inf)
    # NOTE(review): `zero_idx` is not defined in this function; it is presumably a
    # module-level constant (np.int64(0)) - confirm.

    # TODO: the first call of partial_argmin_reduction in the final loop should be
    # written with only two arguments since "previous_result" does not exist yet.
    # It seems it's not possible to get a good factoring of the code to avoid copying
    # most of the code for this with @dpex.kernel, for now we resort to branching.
    @dpex.kernel
    # fmt: off
    def partial_argmin_reduction(
        data,               # IN       (size,)
        previous_result,    # IN       (current_size,)
        result,             # OUT      (math.ceil(
                            #             (current_size if current_size else size)
                            #             / (2 * work_group_size),)
                            #          ))
    ):
        # fmt: on
        group_id = dpex.get_group_id(zero_idx)
        local_work_id = dpex.get_local_id(zero_idx)
        first_work_id = local_work_id == zero_idx

        # A `previous_result` with a single element is the placeholder passed to
        # the first launch: in that case candidates are read directly from `data`.
        previous_result_size = previous_result.shape[zero_idx]
        has_previous_result = previous_result_size > one_idx
        current_size = (
            previous_result_size if has_previous_result else data.shape[zero_idx]
        )

        local_argmin = dpex.local.array(work_group_size, dtype=np.int32)
        local_data = dpex.local.array(work_group_size, dtype=dtype)

        first_value_idx = group_id * work_group_size * two_long
        x_idx = first_value_idx + local_work_id
        y_idx = first_value_idx + work_group_size + local_work_id

        # Each work item compares up to two candidates and stores the smallest
        # value and its index in `data` into local memory.
        if x_idx >= current_size:
            # Out-of-range slot: pad with +inf so that it never wins a comparison.
            # NB: `local_argmin` is left unset for such slots.
            local_data[local_work_id] = inf
        else:
            if has_previous_result:
                # Candidates are indices into `data` selected by the previous
                # launch.
                x_idx = previous_result[x_idx]

            if y_idx >= current_size:
                local_argmin[local_work_id] = x_idx
                local_data[local_work_id] = data[x_idx]

            else:
                if has_previous_result:
                    y_idx = previous_result[y_idx]

                x_data = data[x_idx]
                y_data = data[y_idx]
                if x_data <= y_data:
                    local_argmin[local_work_id] = x_idx
                    local_data[local_work_id] = x_data
                else:
                    local_argmin[local_work_id] = y_idx
                    local_data[local_work_id] = y_data

        dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
        current_n_work_items = work_group_size
        for i in range(local_n_iterations):
            # We discard half of the remaining active work items at each iteration
            current_n_work_items = current_n_work_items // two_long
            if local_work_id < current_n_work_items:
                local_x_idx = local_work_id
                local_y_idx = local_work_id + current_n_work_items

                x_data = local_data[local_x_idx]
                y_data = local_data[local_y_idx]

                if x_data > y_data:
                    local_data[local_x_idx] = y_data
                    local_argmin[local_x_idx] = local_argmin[local_y_idx]

            dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

        # At this point the two first local slots hold the two remaining
        # candidates of the work group; the first work item writes the index of
        # the smallest one.
        if first_work_id:
            if local_data[zero_idx] <= local_data[one_idx]:
                result[group_id] = local_argmin[zero_idx]
            else:
                result[group_id] = local_argmin[one_idx]

    # As many partial reductions as necessary are chained until only one element
    # remains.
    kernels_and_empty_tensors_tuples = []
    n_groups = size
    # Placeholder of size 1 that tells the first launch to read `data` directly.
    previous_result = dpt.empty((1,), dtype=np.int32, device=device)
    while n_groups > 1:
        n_groups = math.ceil(n_groups / (2 * work_group_size))
        global_size = n_groups * work_group_size
        kernel = partial_argmin_reduction[global_size, work_group_size]
        result = dpt.empty(n_groups, dtype=np.int32, device=device)
        kernels_and_empty_tensors_tuples.append((kernel, previous_result, result))
        previous_result = result

    def argmin_reduction(data):
        # `size >= 2` here, so at least one kernel is scheduled and `result` is
        # always bound when the loop ends.
        for kernel, previous_result, result in kernels_and_empty_tensors_tuples:
            kernel(data, previous_result, result)
        return result

    return argmin_reduction
17 changes: 10 additions & 7 deletions sklearn_numba_dpex/kmeans/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
make_initialize_to_zeros_3d_kernel,
make_broadcast_division_1d_2d_kernel,
make_half_l2_norm_2d_axis0_kernel,
make_sum_reduction_1d_kernel,
make_sum_reduction_2d_axis1_kernel,
)

from sklearn_numba_dpex.kmeans.kernels import (
Expand Down Expand Up @@ -315,15 +315,17 @@ def _lloyd(
dtype=compute_dtype,
)

reduce_inertia_kernel = make_sum_reduction_1d_kernel(
size=n_samples,
reduce_inertia_kernel = make_sum_reduction_2d_axis1_kernel(
size0=n_samples,
size1=None,
work_group_size=work_group_size,
device=self.device,
dtype=compute_dtype,
)

reduce_centroid_shifts_kernel = make_sum_reduction_1d_kernel(
size=n_clusters,
reduce_centroid_shifts_kernel = make_sum_reduction_2d_axis1_kernel(
size0=n_clusters,
size1=None,
work_group_size=work_group_size,
device=self.device,
dtype=compute_dtype,
Expand Down Expand Up @@ -713,8 +715,9 @@ def _get_labels_inertia(
compute_dtype,
)

reduce_inertia_kernel = make_sum_reduction_1d_kernel(
size=n_samples,
reduce_inertia_kernel = make_sum_reduction_2d_axis1_kernel(
size0=n_samples,
size1=None,
work_group_size=work_group_size,
device=self.device,
dtype=compute_dtype,
Expand Down

0 comments on commit ee0897f

Please sign in to comment.