Skip to content

Commit

Permalink
Implement PyOpenCL kernel framework.
Browse files Browse the repository at this point in the history
* Add helper functions for PyOpenCL kernel loading, memory allocation
  and event tracking.

* Add kernels to match the current OpenMP and JAX ones.

* Expand tests to include this infrastructure.

This work might not be merged, but it has served as a useful test
to see what work is involved in the develop and debug cycle when
using PyOpenCL as the backend.
  • Loading branch information
tskisner committed Nov 10, 2024
1 parent defcf34 commit 2f57562
Show file tree
Hide file tree
Showing 83 changed files with 4,530 additions and 324 deletions.
1 change: 1 addition & 0 deletions src/toast/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ add_subdirectory(io)
add_subdirectory(accelerator)
add_subdirectory(tests)
add_subdirectory(jax)
add_subdirectory(opencl)
add_subdirectory(ops)
add_subdirectory(templates)
add_subdirectory(scripts)
Expand Down
8 changes: 8 additions & 0 deletions src/toast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@
* Values "0", "false", or "no" will disable runtime support for hybrid GPU pipelines.
* Requires TOAST_GPU_OPENMP or TOAST_GPU_JAX to be enabled.
TOAST_OPENCL=<value>
* Values "1", "true", or "yes" will enable runtime support for pyopencl.
* Requires pyopencl to be available / importable.
TOAST_OPENCL_DEFAULT=<value>
* Default OpenCL device type, where supported values are "CPU", "GPU",
and "OCLGRIND".
OMP_NUM_THREADS=<integer>
* Toast uses OpenMP threading in several places and the concurrency is set by the
usual environment variable.
Expand Down
47 changes: 33 additions & 14 deletions src/toast/_libtoast/ops_pixels_healpix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ void pixels_healpix_nest_inner(
int64_t n_samp,
int64_t idet,
uint8_t mask,
bool use_flags
bool use_flags,
bool compute_submaps
) {
const double zaxis[3] = {0.0, 0.0, 1.0};
int32_t p_indx = pixel_index[idet];
Expand All @@ -618,8 +619,10 @@ void pixels_healpix_nest_inner(
if (use_flags && ((flags[isamp] & mask) != 0)) {
pixels[poff] = -1;
} else {
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
hsub[sub_map] = 1;
if (compute_submaps) {
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
hsub[sub_map] = 1;
}
}

return;
Expand All @@ -639,7 +642,9 @@ void pixels_healpix_ring_inner(
int64_t n_samp,
int64_t idet,
uint8_t mask,
bool use_flags) {
bool use_flags,
bool compute_submaps
) {
const double zaxis[3] = {0.0, 0.0, 1.0};
int32_t p_indx = pixel_index[idet];
int32_t q_indx = quat_index[idet];
Expand All @@ -658,8 +663,10 @@ void pixels_healpix_ring_inner(
if (use_flags && ((flags[isamp] & mask) != 0)) {
pixels[poff] = -1;
} else {
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
hsub[sub_map] = 1;
if (compute_submaps) {
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
hsub[sub_map] = 1;
}
}

return;
Expand Down Expand Up @@ -1163,6 +1170,7 @@ void init_ops_pixels_healpix(py::module & m) {
int64_t n_pix_submap,
int64_t nside,
bool nest,
bool compute_submaps,
bool use_accel
) {
auto & omgr = OmpManager::get();
Expand Down Expand Up @@ -1195,10 +1203,14 @@ void init_ops_pixels_healpix(py::module & m) {
);
int64_t n_view = temp_shape[0];

// Optionally compute the hit submaps
uint8_t * raw_hsub = extract_buffer <uint8_t> (
hit_submaps, "hit_submaps", 1, temp_shape, {-1}
);
int64_t n_submap = temp_shape[0];
if (! compute_submaps) {
raw_hsub = omgr.null_ptr <uint8_t> ();
}

// Optionally use flags
bool use_flags = true;
Expand All @@ -1225,6 +1237,7 @@ void init_ops_pixels_healpix(py::module & m) {
int64_t * dev_pixels = omgr.device_ptr(raw_pixels);
Interval * dev_intervals = omgr.device_ptr(raw_intervals);
uint8_t * dev_flags = omgr.device_ptr(raw_flags);
uint8_t * dev_hsub = omgr.device_ptr(raw_hsub);

// Make sure the lookup table exists on device
size_t utab_bytes = 0x100 * sizeof(int64_t);
Expand Down Expand Up @@ -1258,9 +1271,9 @@ void init_ops_pixels_healpix(py::module & m) {
n_det, \
n_samp, \
shared_flag_mask, \
compute_submaps, \
use_flags \
) \
map(tofrom : raw_hsub[0 : n_submap])
)
{
if (nest) {
# pragma omp target teams distribute parallel for collapse(3) \
Expand All @@ -1269,6 +1282,7 @@ void init_ops_pixels_healpix(py::module & m) {
dev_pixels, \
dev_quats, \
dev_flags, \
dev_hsub, \
dev_intervals, \
dev_utab \
)
Expand All @@ -1293,14 +1307,15 @@ void init_ops_pixels_healpix(py::module & m) {
raw_pixel_index,
dev_quats,
dev_flags,
raw_hsub,
dev_hsub,
dev_pixels,
n_pix_submap,
adjusted_isamp,
n_samp,
idet,
shared_flag_mask,
use_flags
use_flags,
compute_submaps
);
}
}
Expand All @@ -1312,6 +1327,7 @@ void init_ops_pixels_healpix(py::module & m) {
dev_pixels, \
dev_quats, \
dev_flags, \
dev_hsub, \
dev_intervals, \
dev_utab \
)
Expand All @@ -1335,14 +1351,15 @@ void init_ops_pixels_healpix(py::module & m) {
raw_pixel_index,
dev_quats,
dev_flags,
raw_hsub,
dev_hsub,
dev_pixels,
n_pix_submap,
adjusted_isamp,
n_samp,
idet,
shared_flag_mask,
use_flags
use_flags,
compute_submaps
);
}
}
Expand Down Expand Up @@ -1376,7 +1393,8 @@ void init_ops_pixels_healpix(py::module & m) {
n_samp,
idet,
shared_flag_mask,
use_flags
use_flags,
compute_submaps
);
}
}
Expand Down Expand Up @@ -1404,7 +1422,8 @@ void init_ops_pixels_healpix(py::module & m) {
n_samp,
idet,
shared_flag_mask,
use_flags
use_flags,
compute_submaps
);
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/toast/accelerator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
accel_data_update_device,
accel_data_update_host,
accel_enabled,
accel_wait,
accel_get_device,
use_accel_jax,
use_accel_omp,
use_accel_opencl,
use_hybrid_pipelines,
)
from .kernel_registry import ImplementationType, kernel
Loading

0 comments on commit 2f57562

Please sign in to comment.