diff --git a/device/sycl/CMakeLists.txt b/device/sycl/CMakeLists.txt index bff349cb6b..ae0cb0fdf2 100644 --- a/device/sycl/CMakeLists.txt +++ b/device/sycl/CMakeLists.txt @@ -19,8 +19,13 @@ traccc_add_library( traccc_sycl sycl TYPE SHARED "src/seeding/silicon_pixel_spacepoint_formation_algorithm_default_detector.sycl" "src/seeding/silicon_pixel_spacepoint_formation_algorithm_telescope_detector.sycl" "src/seeding/silicon_pixel_spacepoint_formation.hpp" + # Track fitting algorithm. + "include/traccc/sycl/fitting/kalman_fitting_algorithm.hpp" + "src/fitting/kalman_fitting_algorithm.cpp" + "src/fitting/kalman_fitting_algorithm_constant_field_default_detector.sycl" + "src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.sycl" + "src/fitting/fit_tracks.hpp" # header files - "include/traccc/sycl/fitting/fitting_algorithm.hpp" "include/traccc/sycl/seeding/seeding_algorithm.hpp" "include/traccc/sycl/seeding/seed_finding.hpp" "include/traccc/sycl/seeding/spacepoint_binning.hpp" @@ -30,7 +35,6 @@ traccc_add_library( traccc_sycl sycl TYPE SHARED "include/traccc/sycl/utils/make_prefix_sum_buff.hpp" # implementation files "src/clusterization/clusterization_algorithm.sycl" - "src/fitting/fitting_algorithm.sycl" "src/seeding/seed_finding.sycl" "src/seeding/seeding_algorithm.cpp" "src/seeding/spacepoint_binning.sycl" diff --git a/device/sycl/include/traccc/sycl/fitting/fitting_algorithm.hpp b/device/sycl/include/traccc/sycl/fitting/fitting_algorithm.hpp deleted file mode 100644 index 2395f990bc..0000000000 --- a/device/sycl/include/traccc/sycl/fitting/fitting_algorithm.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/** TRACCC library, part of the ACTS project (R&D line) - * - * (c) 2022 CERN for the benefit of the ACTS project - * - * Mozilla Public License Version 2.0 - */ - -#pragma once - -// SYCL library include(s). -#include "traccc/sycl/utils/queue_wrapper.hpp" - -// Project include(s). -#include "traccc/edm/track_candidate.hpp" -#include "traccc/edm/track_state.hpp" -#include "traccc/utils/algorithm.hpp" -#include "traccc/utils/memory_resource.hpp" - -// VecMem include(s). -#include -#include - -// System include(s). -#include - -namespace traccc::sycl { - -/// Fitting algorithm for a set of tracks -template -class fitting_algorithm - : public algorithm { - - public: - using algebra_type = typename fitter_t::algebra_type; - /// Configuration type - using config_type = typename fitter_t::config_type; - - /// Constructor for the fitting algorithm - /// - /// @param mr The memory resource to use - /// @param queue is a wrapper for the sycl queue for kernel invocation - fitting_algorithm(const config_type& cfg, const traccc::memory_resource& mr, - queue_wrapper queue); - - /// Run the algorithm - track_state_container_types::buffer operator()( - const typename fitter_t::detector_type::view_type& det_view, - const typename fitter_t::bfield_type& field_view, - const typename track_candidate_container_types::const_view& - track_candidates_view) const override; - - private: - /// Config object - config_type m_cfg; - /// Memory resource used by the algorithm - traccc::memory_resource m_mr; - /// Queue wrapper - mutable queue_wrapper m_queue; - /// Copy object used by the algorithm - std::unique_ptr m_copy; -}; - -} // namespace traccc::sycl diff --git a/device/sycl/include/traccc/sycl/fitting/kalman_fitting_algorithm.hpp b/device/sycl/include/traccc/sycl/fitting/kalman_fitting_algorithm.hpp new file mode 100644 index 0000000000..0ea43ec697 --- /dev/null +++ b/device/sycl/include/traccc/sycl/fitting/kalman_fitting_algorithm.hpp @@ -0,0 +1,95 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// SYCL library include(s). +#include "traccc/sycl/utils/queue_wrapper.hpp" + +// Project include(s). +#include "traccc/edm/track_candidate.hpp" +#include "traccc/edm/track_state.hpp" +#include "traccc/fitting/fitting_config.hpp" +#include "traccc/geometry/detector.hpp" +#include "traccc/utils/algorithm.hpp" +#include "traccc/utils/memory_resource.hpp" + +// Detray include(s). +#include + +// VecMem include(s). +#include + +// System include(s). +#include + +namespace traccc::sycl { + +/// Kalman filter based track fitting algorithm +class kalman_fitting_algorithm + : public algorithm, + public algorithm { + + public: + /// Configuration type + using config_type = fitting_config; + /// Output type + using output_type = track_state_container_types::buffer; + + /// Constructor with the algorithm's configuration + /// + /// @param config The configuration object + /// + kalman_fitting_algorithm(const config_type& config, + const traccc::memory_resource& mr, + vecmem::copy& copy, queue_wrapper queue); + + /// Execute the algorithm + /// + /// @param det The (default) detector object + /// @param field The (constant) magnetic field object + /// @param track_candidates All track candidates to fit + /// + /// @return A container of the fitted track states + /// + output_type operator()(const default_detector::view& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& + track_candidates) const override; + + /// Execute the algorithm + /// + /// @param det The (telescope) detector object + /// @param field The (constant) magnetic field object + /// @param track_candidates All track candidates to fit + /// + /// @return A container of the fitted track states + /// + output_type operator()(const telescope_detector::view& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& + track_candidates) const override; + + private: + /// Algorithm configuration + config_type m_config; + /// Memory resource used by the algorithm + traccc::memory_resource m_mr; + /// Copy object used by the algorithm + std::reference_wrapper m_copy; + /// Queue wrapper + mutable queue_wrapper m_queue; + +}; // class kalman_fitting_algorithm + +} // namespace traccc::sycl diff --git a/device/sycl/src/fitting/fit_tracks.hpp b/device/sycl/src/fitting/fit_tracks.hpp new file mode 100644 index 0000000000..1c5e063c26 --- /dev/null +++ b/device/sycl/src/fitting/fit_tracks.hpp @@ -0,0 +1,136 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// Local include(s). +#include "traccc/sycl/utils/calculate1DimNdRange.hpp" + +// Project include(s). +#include "traccc/edm/device/sort_key.hpp" +#include "traccc/edm/track_candidate.hpp" +#include "traccc/edm/track_state.hpp" +#include "traccc/fitting/device/fill_sort_keys.hpp" +#include "traccc/fitting/device/fit.hpp" +#include "traccc/fitting/fitting_config.hpp" +#include "traccc/utils/memory_resource.hpp" + +// VecMem include(s). +#include + +// oneDPL include(s). +#include +#include + +// SYCL include(s). +#include + +namespace traccc::sycl { +namespace kernels { + +/// Identifier for the kernel that fills the sorting keys. +struct fill_sort_keys; + +} // namespace kernels + +namespace details { + +template +track_state_container_types::buffer fit_tracks( + const typename fitter_t::detector_type::view_type& det_view, + const typename fitter_t::bfield_type& field_view, + const typename track_candidate_container_types::const_view& + track_candidates_view, + const fitting_config& config, const memory_resource& mr, vecmem::copy& copy, + cl::sycl::queue& queue) { + + // Get the number of tracks. + const track_candidate_container_types::const_device::header_vector:: + size_type n_tracks = copy.get_size(track_candidates_view.headers); + + // Get the number of the track candidates (measurements) in each track. + const std::vector + candidate_sizes = copy.get_sizes(track_candidates_view.items); + + // Create the result buffer. + track_state_container_types::buffer track_states_buffer{ + {n_tracks, mr.main}, + {candidate_sizes, mr.main, mr.host, + vecmem::data::buffer_type::resizable}}; + vecmem::copy::event_type track_states_headers_setup_event = + copy.setup(track_states_buffer.headers); + vecmem::copy::event_type track_states_items_setup_event = + copy.setup(track_states_buffer.items); + + // Return early, if there are no tracks. + if (n_tracks == 0) { + track_states_headers_setup_event->wait(); + track_states_items_setup_event->wait(); + return track_states_buffer; + } + + // Create the buffers for sorting the parameter IDs. + vecmem::data::vector_buffer keys_buffer(n_tracks, + mr.main); + vecmem::data::vector_buffer param_ids_buffer(n_tracks, + mr.main); + vecmem::copy::event_type keys_setup_event = copy.setup(keys_buffer); + vecmem::copy::event_type param_ids_setup_event = + copy.setup(param_ids_buffer); + keys_setup_event->wait(); + param_ids_setup_event->wait(); + + // The execution range for the two kernels of the function. + static constexpr unsigned int localSize = 64; + cl::sycl::nd_range<1> range = calculate1DimNdRange(n_tracks, localSize); + + // Fill the keys and param_ids buffers. + cl::sycl::event fill_keys_event = queue.submit([&](cl::sycl::handler& h) { + h.parallel_for( + range, + [track_candidates_view, keys_view = vecmem::get_data(keys_buffer), + param_ids_view = vecmem::get_data(param_ids_buffer)]( + cl::sycl::nd_item<1> item) { + device::fill_sort_keys(item.get_global_linear_id(), + track_candidates_view, keys_view, + param_ids_view); + }); + }); + + // Sort the key to get the sorted parameter ids + vecmem::device_vector keys_device(keys_buffer); + vecmem::device_vector param_ids_device(param_ids_buffer); + fill_keys_event.wait_and_throw(); + oneapi::dpl::sort_by_key(oneapi::dpl::execution::dpcpp_default, + keys_device.begin(), keys_device.end(), + param_ids_device.begin()); + + // Run the fitting, using the sorted parameter IDs. + track_state_container_types::view track_states_view = track_states_buffer; + track_states_headers_setup_event->wait(); + track_states_items_setup_event->wait(); + queue + .submit([&](cl::sycl::handler& h) { + h.parallel_for( + range, [det_view, field_view, config, track_candidates_view, + param_ids_view = vecmem::get_data(param_ids_buffer), + track_states_view](cl::sycl::nd_item<1> item) { + device::fit(item.get_global_linear_id(), det_view, + field_view, config, + track_candidates_view, param_ids_view, + track_states_view); + }); + }) + .wait_and_throw(); + + // Return the fitted tracks. + return track_states_buffer; +} + +} // namespace details +} // namespace traccc::sycl diff --git a/device/sycl/src/fitting/fitting_algorithm.sycl b/device/sycl/src/fitting/fitting_algorithm.sycl deleted file mode 100644 index fa6945e473..0000000000 --- a/device/sycl/src/fitting/fitting_algorithm.sycl +++ /dev/null @@ -1,148 +0,0 @@ -/** TRACCC library, part of the ACTS project (R&D line) - * - * (c) 2022-2024 CERN for the benefit of the ACTS project - * - * Mozilla Public License Version 2.0 - */ - -// Project include(s). -#include "../utils/get_queue.hpp" -#include "traccc/fitting/device/fill_sort_keys.hpp" -#include "traccc/fitting/device/fit.hpp" -#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" -#include "traccc/sycl/fitting/fitting_algorithm.hpp" -#include "traccc/sycl/utils/calculate1DimNdRange.hpp" - -// detray include(s). -#include "detray/core/detector_metadata.hpp" -#include "detray/detectors/bfield.hpp" -#include "detray/navigation/navigator.hpp" -#include "detray/propagator/rk_stepper.hpp" - -// DPL include(s). -#include -#include - -// System include(s). -#include - -namespace traccc::sycl { - -namespace kernels { -/// Class identifying the kernel running @c -/// traccc::device::fit -class fit; -/// Class identifying the kernel running @c -/// traccc::device::fill_sort_keys -class fill_sort_keys; -} // namespace kernels - -template -fitting_algorithm::fitting_algorithm( - const config_type& cfg, const traccc::memory_resource& mr, - queue_wrapper queue) - : m_cfg(cfg), m_mr(mr), m_queue(queue) { - - // Initialize m_copy ptr based on memory resources that were given - if (mr.host) { - m_copy = std::make_unique(queue.queue()); - } else { - m_copy = std::make_unique(); - } -} - -template -track_state_container_types::buffer fitting_algorithm::operator()( - const typename fitter_t::detector_type::view_type& det_view, - const typename fitter_t::bfield_type& field_view, - const typename track_candidate_container_types::const_view& - track_candidates_view) const { - - // Number of tracks - const track_candidate_container_types::const_device::header_vector:: - size_type n_tracks = m_copy->get_size(track_candidates_view.headers); - - // Get the sizes of the track candidates in each track - const std::vector - candidate_sizes = m_copy->get_sizes(track_candidates_view.items); - - track_state_container_types::buffer track_states_buffer{ - {n_tracks, m_mr.main}, - {candidate_sizes, m_mr.main, m_mr.host, - vecmem::data::buffer_type::resizable}}; - - vecmem::copy::event_type track_states_headers_setup_event = - m_copy->setup(track_states_buffer.headers); - vecmem::copy::event_type track_states_items_setup_event = - m_copy->setup(track_states_buffer.items); - - track_state_container_types::view track_states_view(track_states_buffer); - - // -- localSize - // The dimension of workGroup (block) is the integer multiple of WARP_SIZE - // (=32) - unsigned int localSize = 64; - - vecmem::data::vector_buffer keys_buffer(n_tracks, - m_mr.main); - vecmem::data::vector_buffer param_ids_buffer(n_tracks, - m_mr.main); - vecmem::data::vector_view keys_view(keys_buffer); - vecmem::data::vector_view param_ids_view(param_ids_buffer); - - // Sort the key to get the sorted parameter ids - vecmem::device_vector keys_device(keys_buffer); - vecmem::device_vector param_ids_device(param_ids_buffer); - - // 1 dim ND Range for the kernel - auto trackParamsNdRange = - traccc::sycl::calculate1DimNdRange(n_tracks, localSize); - - details::get_queue(m_queue) - .submit([&](::sycl::handler& h) { - h.parallel_for( - trackParamsNdRange, [track_candidates_view, keys_view, - param_ids_view](::sycl::nd_item<1> item) { - device::fill_sort_keys(item.get_global_linear_id(), - track_candidates_view, keys_view, - param_ids_view); - }); - }) - .wait_and_throw(); - - oneapi::dpl::sort_by_key(oneapi::dpl::execution::dpcpp_default, - keys_device.begin(), keys_device.end(), - param_ids_device.begin()); - - track_states_headers_setup_event->wait(); - track_states_items_setup_event->wait(); - details::get_queue(m_queue) - .submit([&](::sycl::handler& h) { - h.parallel_for( - trackParamsNdRange, - [det_view, field_view, config = m_cfg, track_candidates_view, - param_ids_view, track_states_view](::sycl::nd_item<1> item) { - device::fit(item.get_global_linear_id(), det_view, - field_view, config, - track_candidates_view, param_ids_view, - track_states_view); - }); - }) - .wait_and_throw(); - - return track_states_buffer; -} - -// Explicit template instantiation -using default_detector_type = - detray::detector; -using default_stepper_type = - detray::rk_stepper::view_t, - default_algebra, detray::constrained_step<>>; -using default_navigator_type = detray::navigator; -using default_fitter_type = - kalman_fitter; -template class fitting_algorithm; - -} // namespace traccc::sycl diff --git a/device/sycl/src/fitting/kalman_fitting_algorithm.cpp b/device/sycl/src/fitting/kalman_fitting_algorithm.cpp new file mode 100644 index 0000000000..715ddbaca0 --- /dev/null +++ b/device/sycl/src/fitting/kalman_fitting_algorithm.cpp @@ -0,0 +1,18 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +// Local include(s). +#include "traccc/sycl/fitting/kalman_fitting_algorithm.hpp" + +namespace traccc::sycl { + +kalman_fitting_algorithm::kalman_fitting_algorithm( + const config_type& config, const traccc::memory_resource& mr, + vecmem::copy& copy, queue_wrapper queue) + : m_config{config}, m_mr{mr}, m_copy{copy}, m_queue{queue} {} + +} // namespace traccc::sycl diff --git a/device/sycl/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.sycl b/device/sycl/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.sycl new file mode 100644 index 0000000000..eca547e4ce --- /dev/null +++ b/device/sycl/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.sycl @@ -0,0 +1,48 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +// Local include(s). +#include "../utils/get_queue.hpp" +#include "fit_tracks.hpp" +#include "traccc/sycl/fitting/kalman_fitting_algorithm.hpp" + +// Project include(s). +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" + +// Detray include(s). +#include +#include + +namespace traccc::sycl { +namespace kernels { + +/// Identifier for the track fitting kernel. +struct fit_tracks_constant_field_default_detector; + +} // namespace kernels + +kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( + const default_detector::view& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& track_candidates) const { + + // Construct the fitter type. + using stepper_type = + detray::rk_stepper>; + using navigator_type = detray::navigator; + using fitter_type = kalman_fitter; + + // Run the track fitting. + return details::fit_tracks< + fitter_type, kernels::fit_tracks_constant_field_default_detector>( + det, field, track_candidates, m_config, m_mr, m_copy.get(), + details::get_queue(m_queue)); +} + +} // namespace traccc::sycl diff --git a/device/sycl/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.sycl b/device/sycl/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.sycl new file mode 100644 index 0000000000..073209ef1c --- /dev/null +++ b/device/sycl/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.sycl @@ -0,0 +1,48 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +// Local include(s). +#include "../utils/get_queue.hpp" +#include "fit_tracks.hpp" +#include "traccc/sycl/fitting/kalman_fitting_algorithm.hpp" + +// Project include(s). +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" + +// Detray include(s). +#include +#include + +namespace traccc::sycl { +namespace kernels { + +/// Identifier for the track fitting kernel. +struct fit_tracks_constant_field_telescope_detector; + +} // namespace kernels + +kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( + const telescope_detector::view& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& track_candidates) const { + + // Construct the fitter type. + using stepper_type = + detray::rk_stepper>; + using navigator_type = detray::navigator; + using fitter_type = kalman_fitter; + + // Run the track fitting. + return details::fit_tracks< + fitter_type, kernels::fit_tracks_constant_field_telescope_detector>( + det, field, track_candidates, m_config, m_mr, m_copy.get(), + details::get_queue(m_queue)); +} + +} // namespace traccc::sycl