Skip to content

Commit

Permalink
More improvements to event dependency handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
tskisner committed Nov 6, 2024
1 parent 32a8bbd commit 6402411
Show file tree
Hide file tree
Showing 20 changed files with 317 additions and 118 deletions.
5 changes: 3 additions & 2 deletions src/toast/accelerator/accel.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,9 @@ def accel_data_update_host(data, name="None"):
return data.to_host()
elif use_accel_opencl:
ocl = OpenCL()
ev = ocl.mem_update_host(data, name=name)
return ev
evs = list()
evs.append(ocl.mem_update_host(data, name=name, async_=True))
return evs
else:
log = Logger.get()
log.warning("Accelerator support not enabled, not updating host")
Expand Down
2 changes: 1 addition & 1 deletion src/toast/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def accel_update_device(self, names):
else:
log.verbose(f"Calling Data update_device for '{key}'")
ev = val.accel_update_device()
print(f"DATA extend with global:{key} = {ev}")
# print(f"DATA extend with global:{key} = {ev}")
events[first_ob].extend(ev)
else:
msg = f"Data accel_update_device: '{key}' ({type(val)}) "
Expand Down
6 changes: 5 additions & 1 deletion src/toast/observation_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def _accel_exists(self):

def _accel_create(self, zero_out=False):
if use_accel_omp or use_accel_opencl:
print(f"DD create {self._accel_name}")
self._raw = accel_data_create(
self._raw, self._accel_name, zero_out=zero_out
)
Expand All @@ -572,6 +573,7 @@ def _accel_update_device(self):
self._data = accel_data_update_device(self._data)
elif use_accel_opencl:
dev_data = accel_data_update_device(self._raw, self._accel_name)
print(f"DD update device {self._accel_name}, evs={dev_data.events}")
return dev_data.events

def _accel_update_host(self):
Expand All @@ -580,7 +582,9 @@ def _accel_update_host(self):
elif use_accel_jax:
self._data = accel_data_update_host(self._data)
elif use_accel_opencl:
return accel_data_update_host(self._raw, self._accel_name)
evs = accel_data_update_host(self._raw, self._accel_name)
print(f"DD update host {self._accel_name}, evs={evs}")
return evs

def _accel_delete(self):
if use_accel_omp or use_accel_opencl:
Expand Down
9 changes: 8 additions & 1 deletion src/toast/opencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,12 @@
"""OpenCL tools.
"""

from .utils import have_opencl, find_source, get_kernel_deps, add_kernel_dep
from .utils import (
have_opencl,
find_source,
get_kernel_deps,
add_kernel_deps,
replace_kernel_deps,
clear_kernel_deps,
)
from .platform import OpenCL
90 changes: 71 additions & 19 deletions src/toast/opencl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,12 @@ def aligned_to_dtype(aligned):
raise ValueError(msg)


def get_kernel_deps(state, obs_name, clear=True):
def get_kernel_deps(state, obs_name):
"""Extract kernel wait_for events for the current observation.
Args:
state (dict): The state dictionary
obs_name (str): The observation name
clear (bool): If True, clear the observation events before
returning them.
Returns:
(list): The list of events to wait on.
Expand All @@ -119,26 +117,83 @@ def get_kernel_deps(state, obs_name, clear=True):
# No dependencies for this observation
# print(f"GET {obs_name}: obs_name not in state", flush=True)
return list()
if not isinstance(state[obs_name], list):
msg = f"kernel state for obs {obs_name} is not a list"
# Return events
return state[obs_name]


def clear_kernel_deps(state, obs_name):
    """Clear kernel events for a given observation.

    This should be done **after** the events are completed.

    A `None` state, or an observation with no entry in the state, is treated
    as "nothing to clear" and the function returns without error.

    Args:
        state (dict): The state dictionary, keyed on observation name.
        obs_name (str): The observation name.

    Returns:
        None

    Raises:
        RuntimeError: If `obs_name` is None, or if `state` is neither None
            nor a dictionary.

    """
    if obs_name is None:
        msg = "Observation name cannot be None"
        raise RuntimeError(msg)
    if state is None:
        # No dependencies
        return
    if not isinstance(state, dict):
        msg = "kernel state should be a dictionary keyed on observation name"
        raise RuntimeError(msg)
    if obs_name not in state:
        # No dependencies for this observation.  This check must come before
        # any access to state[obs_name], otherwise an unknown observation
        # would raise KeyError instead of being a no-op.
        return
    # Empty the list in place, so the same list object remains the entry
    # for this observation.
    state[obs_name].clear()


def replace_kernel_deps(state, obs_name, events):
    """Clear the events for a given observation and replace.

    The event list for the specified observation is created if needed.
    An existing list is emptied and refilled in place.

    Args:
        state (dict): The state dictionary, keyed on observation name.
        obs_name (str): The observation name.
        events (Event, list): pyopencl event or list of events.  If None,
            the observation is left with an empty event list.

    Returns:
        None

    Raises:
        RuntimeError: If `obs_name` is None, if `state` is None, or if
            `state` is not a dictionary.

    """
    if obs_name is None:
        msg = "Observation name cannot be None"
        raise RuntimeError(msg)
    if state is None:
        msg = "State dictionary cannot be None"
        raise RuntimeError(msg)
    if not isinstance(state, dict):
        msg = "kernel state should be a dictionary keyed on observation name"
        raise RuntimeError(msg)

    # Snapshot the replacement events BEFORE clearing.  The caller may pass
    # the very list that is stored in state[obs_name]; clearing that list in
    # place first would also empty `events` and lose every event.
    if events is None:
        new_events = []
    elif isinstance(events, list):
        new_events = list(events)
    else:
        new_events = [events]

    if obs_name in state:
        state[obs_name].clear()
    else:
        state[obs_name] = list()
    state[obs_name].extend(new_events)


def add_kernel_dep(state, obs_name, event):
def add_kernel_deps(state, obs_name, events):
"""Append event(s) to the current observation state.
The event list for the specified observation is created if needed.
Args:
state (dict): The state dictionary
obs_name (str): The observation name
event (Event): pyopencl event or list of events.
events (Event, list): pyopencl event or list of events.
Returns:
None
Expand All @@ -151,13 +206,10 @@ def add_kernel_dep(state, obs_name, event):
msg = "State dictionary cannot be None"
raise RuntimeError(msg)
if obs_name not in state:
# Create the entry for this observation
# print(f"SET {obs_name}: create obs_name list", flush=True)
state[obs_name] = list()
# print(f"SET {obs_name}: add event {event}", flush=True)
if event is None:
if events is None:
return
if isinstance(event, list):
state[obs_name].extend(event)
if isinstance(events, list):
state[obs_name].extend(events)
else:
state[obs_name].append(event)
state[obs_name].append(events)
3 changes: 3 additions & 0 deletions src/toast/ops/mapmaker_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ def _exec(self, data, detectors=None, **kwargs):
log.verbose(" BinMap running pipeline")
pipe_out = accum.apply(data, detectors=detectors)

good_pix = data[self.binned].data != 0
print(f"Binned zmap = {data[self.binned].data[good_pix]}")

# print("Binned zmap = ", data[self.binned].data)

# Optionally, store the noise-weighted map
Expand Down
15 changes: 15 additions & 0 deletions src/toast/ops/mapmaker_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def _exec(self, data, detectors=None, **kwargs):
self.binning.det_data_units = self.det_data_units
self.binning.apply(data, detectors=detectors)

good_pix = data[self.binning.binned].data != 0
print(f"RHS binned = {data[self.binning.binned].data[good_pix]}")

log.debug_rank("MapMaker RHS binned map finished in", comm=comm, timer=timer)

# Build a pipeline for the projection and template matrix application.
Expand Down Expand Up @@ -220,6 +223,10 @@ def _exec(self, data, detectors=None, **kwargs):
"MapMaker RHS begin cleanup temporary detector data", comm=comm
)

for tkey in data[self.template_matrix.amplitudes].keys():
good_amps = data[self.template_matrix.amplitudes][tkey].local != 0
print(f"RHS {tkey}: {data[self.template_matrix.amplitudes][tkey].local[good_amps]}")

# Clean up our temp buffer
delete_temp = Delete(detdata=[det_temp])
delete_temp.apply(data)
Expand Down Expand Up @@ -370,6 +377,10 @@ def _exec(self, data, detectors=None, **kwargs):
timer.start()
log.debug_rank("MapMaker LHS begin project amplitudes and binning", comm=comm)

# for tkey in data[self.template_matrix.amplitudes].keys():
# good_amps = data[self.template_matrix.amplitudes][tkey].local != 0
# print(f"LHS IN {tkey}: {data[self.template_matrix.amplitudes][tkey].local[good_amps]}")

self.template_matrix.transpose = False
self.template_matrix.det_data = self.det_temp
self.template_matrix.det_data_units = self.det_data_units
Expand Down Expand Up @@ -520,6 +531,10 @@ def _exec(self, data, detectors=None, **kwargs):

proj_pipe.apply(data, detectors=detectors)

# for tkey in data[self.out].keys():
# good_amps = data[self.out][tkey].local != 0
# print(f"LHS OUT {tkey}: {data[self.out][tkey].local[good_amps]}")

log.debug_rank(
"MapMaker LHS map scan and amplitude accumulate finished in",
comm=comm,
Expand Down
6 changes: 6 additions & 0 deletions src/toast/ops/mapmaker_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,12 @@ def _get_pixel_covariance(self, solve_pixels, solve_weights):

solver_cov.apply(self._data, detectors=self._detectors)

good_hits = self._data[self.solver_hits_name].data != 0
print(f"Solve hits = {self._data[self.solver_hits_name].data[good_hits]}")

good_pix = self._data[self.solver_cov_name].data != 0
print(f"Solve covariance = {self._data[self.solver_cov_name].data[good_pix]}")

self._memreport.prefix = "After constructing covariance and hits"
self._memreport.apply(self._data)

Expand Down
2 changes: 1 addition & 1 deletion src/toast/ops/mapmaker_utils/kernels_opencl.cl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ __kernel void build_noise_weighted(

scaled_data = det_scale[d_indx] * det_data[doff];
for (long i = 0; i < nnz; i++) {
//atomic_add_double(&(zmap[zoff + i]), scaled_data * weights[woff + i]);
atomic_add_double(&(zmap[zoff + i]), scaled_data * weights[woff + i]);
}
}

Expand Down
32 changes: 21 additions & 11 deletions src/toast/ops/mapmaker_utils/kernels_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
import pyopencl as cl

from ...accelerator import ImplementationType, kernel
from ...opencl import find_source, OpenCL, add_kernel_dep, get_kernel_deps
from ...opencl import (
find_source,
OpenCL,
add_kernel_deps,
get_kernel_deps,
replace_kernel_deps,
clear_kernel_deps,
)


@kernel(impl=ImplementationType.OPENCL, name="build_noise_weighted")
Expand Down Expand Up @@ -45,7 +52,7 @@ def build_noise_weighted_opencl(
use_det_flags = np.uint8(0)

ocl = OpenCL()

queue = ocl.queue()
devtype = ocl.default_device_type

kernel = ocl.get_or_build_kernel(
Expand All @@ -67,23 +74,26 @@ def build_noise_weighted_opencl(
# Allocate temporary device arrays

dev_pixels_index = ocl.mem_to_device(pixels_index, device_type=devtype, async_=True)
add_kernel_dep(state, obs_name, dev_pixels_index.events)
add_kernel_deps(state, obs_name, dev_pixels_index.events)

dev_weight_index = ocl.mem_to_device(weight_index, device_type=devtype, async_=True)
add_kernel_dep(state, obs_name, dev_weight_index.events)
add_kernel_deps(state, obs_name, dev_weight_index.events)

dev_det_data_index = ocl.mem_to_device(det_data_index, device_type=devtype, async_=True)
add_kernel_dep(state, obs_name, dev_det_data_index.events)
dev_det_data_index = ocl.mem_to_device(
det_data_index, device_type=devtype, async_=True
)
add_kernel_deps(state, obs_name, dev_det_data_index.events)

dev_flag_index = ocl.mem_to_device(flag_index, device_type=devtype, async_=True)
add_kernel_dep(state, obs_name, dev_flag_index.events)
add_kernel_deps(state, obs_name, dev_flag_index.events)

dev_det_scale = ocl.mem_to_device(det_scale, device_type=devtype, async_=True)
add_kernel_dep(state, obs_name, dev_det_scale.events)
add_kernel_deps(state, obs_name, dev_det_scale.events)

# All of the events that our kernels depend on
wait_for = get_kernel_deps(state, obs_name)
#print(f"BLDNSEW: {obs_name} got wait_for = {wait_for}", flush=True)
print(f"BLDNSEW: {obs_name} got wait_for = {wait_for}", flush=True)
print(f"BLDNSEW: {obs_name} pixels={dev_pixels}, weights={dev_weights}, zmap={dev_zmap}", flush=True)

n_det = len(det_data_index)
n_samp = weights.shape[1]
Expand Down Expand Up @@ -121,12 +131,12 @@ def build_noise_weighted_opencl(
wait_for=wait_for,
)
wait_for = [ev]
add_kernel_dep(state, obs_name, wait_for)
clear_kernel_deps(state, obs_name)
add_kernel_deps(state, obs_name, wait_for)

# Free temporaries
ocl.mem_remove(pixels_index, device_type=devtype)
ocl.mem_remove(weight_index, device_type=devtype)
ocl.mem_remove(det_data_index, device_type=devtype)
ocl.mem_remove(flag_index, device_type=devtype)
ocl.mem_remove(det_scale, device_type=devtype)

2 changes: 2 additions & 0 deletions src/toast/ops/mapmaker_utils/mapmaker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs):
else:
shared_flag_data = np.zeros(1, dtype=np.uint8)

print(f"BLD {ob.name} {self.pixels}={ob.detdata[self.pixels].data}, {self.weights}={ob.detdata[self.weights].data}, {self.det_data}={ob.detdata[self.det_data].data}, {self.view}={ob.intervals[self.view].data}")

build_noise_weighted(
zmap.distribution.global_submap_to_local,
zmap.data,
Expand Down
16 changes: 12 additions & 4 deletions src/toast/ops/noise_weight/kernels_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
import pyopencl as cl

from ...accelerator import ImplementationType, kernel
from ...opencl import find_source, OpenCL, add_kernel_dep, get_kernel_deps
from ...opencl import (
find_source,
OpenCL,
add_kernel_deps,
get_kernel_deps,
clear_kernel_deps,
)


@kernel(impl=ImplementationType.OPENCL, name="noise_weight")
Expand All @@ -24,6 +30,7 @@ def noise_weight_opencl(
program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl")

ocl = OpenCL()
queue = ocl.queue()
devtype = ocl.default_device_type

# Get our kernel
Expand All @@ -41,11 +48,11 @@ def noise_weight_opencl(
dev_det_data_index = ocl.mem_to_device(
det_data_index, device_type=devtype, async_=True
)
add_kernel_dep(state, obs_name, dev_det_data_index.events)
add_kernel_deps(state, obs_name, dev_det_data_index.events)
dev_det_weights = ocl.mem_to_device(
detector_weights, device_type=devtype, async_=True
)
add_kernel_dep(state, obs_name, dev_det_weights.events)
add_kernel_deps(state, obs_name, dev_det_weights.events)

# All of the events that our kernels depend on
wait_for = get_kernel_deps(state, obs_name)
Expand All @@ -69,7 +76,8 @@ def noise_weight_opencl(
wait_for=wait_for,
)
wait_for = [ev]
add_kernel_dep(state, obs_name, wait_for)
clear_kernel_deps(state, obs_name)
add_kernel_deps(state, obs_name, wait_for)

# Free temporaries
ocl.mem_remove(det_data_index, device_type=devtype)
Expand Down
Loading

0 comments on commit 6402411

Please sign in to comment.