-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
De-templated the SYCL version of the Kalman track fitting.
- Loading branch information
Showing
8 changed files
with
351 additions
and
216 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 0 additions & 66 deletions
66
device/sycl/include/traccc/sycl/fitting/fitting_algorithm.hpp
This file was deleted.
Oops, something went wrong.
95 changes: 95 additions & 0 deletions
95
device/sycl/include/traccc/sycl/fitting/kalman_fitting_algorithm.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/** TRACCC library, part of the ACTS project (R&D line) | ||
* | ||
* (c) 2022-2024 CERN for the benefit of the ACTS project | ||
* | ||
* Mozilla Public License Version 2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
// SYCL library include(s). | ||
#include "traccc/sycl/utils/queue_wrapper.hpp" | ||
|
||
// Project include(s). | ||
#include "traccc/edm/track_candidate.hpp" | ||
#include "traccc/edm/track_state.hpp" | ||
#include "traccc/fitting/fitting_config.hpp" | ||
#include "traccc/geometry/detector.hpp" | ||
#include "traccc/utils/algorithm.hpp" | ||
#include "traccc/utils/memory_resource.hpp" | ||
|
||
// Detray include(s). | ||
#include <detray/detectors/bfield.hpp> | ||
|
||
// VecMem include(s). | ||
#include <vecmem/utils/copy.hpp> | ||
|
||
// System include(s). | ||
#include <functional> | ||
|
||
namespace traccc::sycl { | ||
|
||
/// Kalman filter based track fitting algorithm | ||
class kalman_fitting_algorithm | ||
: public algorithm<track_state_container_types::buffer( | ||
const default_detector::view&, | ||
const detray::bfield::const_field_t::view_t&, | ||
const track_candidate_container_types::const_view&)>, | ||
public algorithm<track_state_container_types::buffer( | ||
const telescope_detector::view&, | ||
const detray::bfield::const_field_t::view_t&, | ||
const track_candidate_container_types::const_view&)> { | ||
|
||
public: | ||
/// Configuration type | ||
using config_type = fitting_config; | ||
/// Output type | ||
using output_type = track_state_container_types::buffer; | ||
|
||
/// Constructor with the algorithm's configuration | ||
/// | ||
/// @param config The configuration object | ||
/// | ||
kalman_fitting_algorithm(const config_type& config, | ||
const traccc::memory_resource& mr, | ||
vecmem::copy& copy, queue_wrapper queue); | ||
|
||
/// Execute the algorithm | ||
/// | ||
/// @param det The (default) detector object | ||
/// @param field The (constant) magnetic field object | ||
/// @param track_candidates All track candidates to fit | ||
/// | ||
/// @return A container of the fitted track states | ||
/// | ||
output_type operator()(const default_detector::view& det, | ||
const detray::bfield::const_field_t::view_t& field, | ||
const track_candidate_container_types::const_view& | ||
track_candidates) const override; | ||
|
||
/// Execute the algorithm | ||
/// | ||
/// @param det The (telescope) detector object | ||
/// @param field The (constant) magnetic field object | ||
/// @param track_candidates All track candidates to fit | ||
/// | ||
/// @return A container of the fitted track states | ||
/// | ||
output_type operator()(const telescope_detector::view& det, | ||
const detray::bfield::const_field_t::view_t& field, | ||
const track_candidate_container_types::const_view& | ||
track_candidates) const override; | ||
|
||
private: | ||
/// Algorithm configuration | ||
config_type m_config; | ||
/// Memory resource used by the algorithm | ||
traccc::memory_resource m_mr; | ||
/// Copy object used by the algorithm | ||
std::reference_wrapper<vecmem::copy> m_copy; | ||
/// Queue wrapper | ||
mutable queue_wrapper m_queue; | ||
|
||
}; // class kalman_fitting_algorithm | ||
|
||
} // namespace traccc::sycl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
/** TRACCC library, part of the ACTS project (R&D line) | ||
* | ||
* (c) 2022-2024 CERN for the benefit of the ACTS project | ||
* | ||
* Mozilla Public License Version 2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
// Local include(s). | ||
#include "traccc/sycl/utils/calculate1DimNdRange.hpp" | ||
|
||
// Project include(s). | ||
#include "traccc/edm/device/sort_key.hpp" | ||
#include "traccc/edm/track_candidate.hpp" | ||
#include "traccc/edm/track_state.hpp" | ||
#include "traccc/fitting/device/fill_sort_keys.hpp" | ||
#include "traccc/fitting/device/fit.hpp" | ||
#include "traccc/fitting/fitting_config.hpp" | ||
#include "traccc/utils/memory_resource.hpp" | ||
|
||
// VecMem include(s). | ||
#include <vecmem/utils/copy.hpp> | ||
|
||
// oneDPL include(s). | ||
#include <oneapi/dpl/algorithm> | ||
#include <oneapi/dpl/execution> | ||
|
||
// SYCL include(s). | ||
#include <CL/sycl.hpp> | ||
|
||
namespace traccc::sycl { | ||
namespace kernels { | ||
|
||
/// Identifier for the kernel that fills the sorting keys. | ||
struct fill_sort_keys; | ||
|
||
} // namespace kernels | ||
|
||
namespace details { | ||
|
||
template <typename fitter_t, typename fit_kernel_t> | ||
track_state_container_types::buffer fit_tracks( | ||
const typename fitter_t::detector_type::view_type& det_view, | ||
const typename fitter_t::bfield_type& field_view, | ||
const typename track_candidate_container_types::const_view& | ||
track_candidates_view, | ||
const fitting_config& config, const memory_resource& mr, vecmem::copy& copy, | ||
cl::sycl::queue& queue) { | ||
|
||
// Get the number of tracks. | ||
const track_candidate_container_types::const_device::header_vector:: | ||
size_type n_tracks = copy.get_size(track_candidates_view.headers); | ||
|
||
// Get the number of the track candidates (measurements) in each track. | ||
const std::vector<track_candidate_container_types::const_device:: | ||
item_vector::value_type::size_type> | ||
candidate_sizes = copy.get_sizes(track_candidates_view.items); | ||
|
||
// Create the result buffer. | ||
track_state_container_types::buffer track_states_buffer{ | ||
{n_tracks, mr.main}, | ||
{candidate_sizes, mr.main, mr.host, | ||
vecmem::data::buffer_type::resizable}}; | ||
vecmem::copy::event_type track_states_headers_setup_event = | ||
copy.setup(track_states_buffer.headers); | ||
vecmem::copy::event_type track_states_items_setup_event = | ||
copy.setup(track_states_buffer.items); | ||
|
||
// Return early, if there are no tracks. | ||
if (n_tracks == 0) { | ||
track_states_headers_setup_event->wait(); | ||
track_states_items_setup_event->wait(); | ||
return track_states_buffer; | ||
} | ||
|
||
// Create the buffers for sorting the parameter IDs. | ||
vecmem::data::vector_buffer<device::sort_key> keys_buffer(n_tracks, | ||
mr.main); | ||
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(n_tracks, | ||
mr.main); | ||
vecmem::copy::event_type keys_setup_event = copy.setup(keys_buffer); | ||
vecmem::copy::event_type param_ids_setup_event = | ||
copy.setup(param_ids_buffer); | ||
keys_setup_event->wait(); | ||
param_ids_setup_event->wait(); | ||
|
||
// The execution range for the two kernels of the function. | ||
static constexpr unsigned int localSize = 64; | ||
cl::sycl::nd_range<1> range = calculate1DimNdRange(n_tracks, localSize); | ||
|
||
// Fill the keys and param_ids buffers. | ||
cl::sycl::event fill_keys_event = queue.submit([&](cl::sycl::handler& h) { | ||
h.parallel_for<kernels::fill_sort_keys>( | ||
range, | ||
[track_candidates_view, keys_view = vecmem::get_data(keys_buffer), | ||
param_ids_view = vecmem::get_data(param_ids_buffer)]( | ||
cl::sycl::nd_item<1> item) { | ||
device::fill_sort_keys(item.get_global_linear_id(), | ||
track_candidates_view, keys_view, | ||
param_ids_view); | ||
}); | ||
}); | ||
|
||
// Sort the key to get the sorted parameter ids | ||
vecmem::device_vector<device::sort_key> keys_device(keys_buffer); | ||
vecmem::device_vector<unsigned int> param_ids_device(param_ids_buffer); | ||
fill_keys_event.wait_and_throw(); | ||
oneapi::dpl::sort_by_key(oneapi::dpl::execution::dpcpp_default, | ||
keys_device.begin(), keys_device.end(), | ||
param_ids_device.begin()); | ||
|
||
// Run the fitting, using the sorted parameter IDs. | ||
track_state_container_types::view track_states_view = track_states_buffer; | ||
track_states_headers_setup_event->wait(); | ||
track_states_items_setup_event->wait(); | ||
queue | ||
.submit([&](cl::sycl::handler& h) { | ||
h.parallel_for<fit_kernel_t>( | ||
range, [det_view, field_view, config, track_candidates_view, | ||
param_ids_view = vecmem::get_data(param_ids_buffer), | ||
track_states_view](cl::sycl::nd_item<1> item) { | ||
device::fit<fitter_t>(item.get_global_linear_id(), det_view, | ||
field_view, config, | ||
track_candidates_view, param_ids_view, | ||
track_states_view); | ||
}); | ||
}) | ||
.wait_and_throw(); | ||
|
||
// Return the fitted tracks. | ||
return track_states_buffer; | ||
} | ||
|
||
} // namespace details | ||
} // namespace traccc::sycl |
Oops, something went wrong.