Skip to content

Commit

Permalink
De-templated the SYCL version of the Kalman track fitting.
Browse files Browse the repository at this point in the history
  • Loading branch information
krasznaa committed Nov 1, 2024
1 parent b513756 commit 1554db0
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 216 deletions.
8 changes: 6 additions & 2 deletions device/sycl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ traccc_add_library( traccc_sycl sycl TYPE SHARED
"src/seeding/silicon_pixel_spacepoint_formation_algorithm_default_detector.sycl"
"src/seeding/silicon_pixel_spacepoint_formation_algorithm_telescope_detector.sycl"
"src/seeding/silicon_pixel_spacepoint_formation.hpp"
# Track fitting algorithm.
"include/traccc/sycl/fitting/kalman_fitting_algorithm.hpp"
"src/fitting/kalman_fitting_algorithm.cpp"
"src/fitting/kalman_fitting_algorithm_constant_field_default_detector.sycl"
"src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.sycl"
"src/fitting/fit_tracks.hpp"
# header files
"include/traccc/sycl/fitting/fitting_algorithm.hpp"
"include/traccc/sycl/seeding/seeding_algorithm.hpp"
"include/traccc/sycl/seeding/seed_finding.hpp"
"include/traccc/sycl/seeding/spacepoint_binning.hpp"
Expand All @@ -30,7 +35,6 @@ traccc_add_library( traccc_sycl sycl TYPE SHARED
"include/traccc/sycl/utils/make_prefix_sum_buff.hpp"
# implementation files
"src/clusterization/clusterization_algorithm.sycl"
"src/fitting/fitting_algorithm.sycl"
"src/seeding/seed_finding.sycl"
"src/seeding/seeding_algorithm.cpp"
"src/seeding/spacepoint_binning.sycl"
Expand Down
66 changes: 0 additions & 66 deletions device/sycl/include/traccc/sycl/fitting/fitting_algorithm.hpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// SYCL library include(s).
#include "traccc/sycl/utils/queue_wrapper.hpp"

// Project include(s).
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/fitting_config.hpp"
#include "traccc/geometry/detector.hpp"
#include "traccc/utils/algorithm.hpp"
#include "traccc/utils/memory_resource.hpp"

// Detray include(s).
#include <detray/detectors/bfield.hpp>

// VecMem include(s).
#include <vecmem/utils/copy.hpp>

// System include(s).
#include <functional>

namespace traccc::sycl {

/// Kalman filter based track fitting algorithm
class kalman_fitting_algorithm
: public algorithm<track_state_container_types::buffer(
const default_detector::view&,
const detray::bfield::const_field_t::view_t&,
const track_candidate_container_types::const_view&)>,
public algorithm<track_state_container_types::buffer(
const telescope_detector::view&,
const detray::bfield::const_field_t::view_t&,
const track_candidate_container_types::const_view&)> {

public:
/// Configuration type
using config_type = fitting_config;
/// Output type
using output_type = track_state_container_types::buffer;

/// Constructor with the algorithm's configuration
///
/// @param config The configuration object
///
kalman_fitting_algorithm(const config_type& config,
const traccc::memory_resource& mr,
vecmem::copy& copy, queue_wrapper queue);

/// Execute the algorithm
///
/// @param det The (default) detector object
/// @param field The (constant) magnetic field object
/// @param track_candidates All track candidates to fit
///
/// @return A container of the fitted track states
///
output_type operator()(const default_detector::view& det,
const detray::bfield::const_field_t::view_t& field,
const track_candidate_container_types::const_view&
track_candidates) const override;

/// Execute the algorithm
///
/// @param det The (telescope) detector object
/// @param field The (constant) magnetic field object
/// @param track_candidates All track candidates to fit
///
/// @return A container of the fitted track states
///
output_type operator()(const telescope_detector::view& det,
const detray::bfield::const_field_t::view_t& field,
const track_candidate_container_types::const_view&
track_candidates) const override;

private:
/// Algorithm configuration
config_type m_config;
/// Memory resource used by the algorithm
traccc::memory_resource m_mr;
/// Copy object used by the algorithm
std::reference_wrapper<vecmem::copy> m_copy;
/// Queue wrapper
mutable queue_wrapper m_queue;

}; // class kalman_fitting_algorithm

} // namespace traccc::sycl
136 changes: 136 additions & 0 deletions device/sycl/src/fitting/fit_tracks.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Local include(s).
#include "traccc/sycl/utils/calculate1DimNdRange.hpp"

// Project include(s).
#include "traccc/edm/device/sort_key.hpp"
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/device/fill_sort_keys.hpp"
#include "traccc/fitting/device/fit.hpp"
#include "traccc/fitting/fitting_config.hpp"
#include "traccc/utils/memory_resource.hpp"

// VecMem include(s).
#include <vecmem/utils/copy.hpp>

// oneDPL include(s).
#include <oneapi/dpl/algorithm>
#include <oneapi/dpl/execution>

// SYCL include(s).
#include <CL/sycl.hpp>

namespace traccc::sycl {
namespace kernels {

/// Identifier for the kernel that fills the sorting keys.
struct fill_sort_keys;

} // namespace kernels

namespace details {

template <typename fitter_t, typename fit_kernel_t>
track_state_container_types::buffer fit_tracks(
const typename fitter_t::detector_type::view_type& det_view,
const typename fitter_t::bfield_type& field_view,
const typename track_candidate_container_types::const_view&
track_candidates_view,
const fitting_config& config, const memory_resource& mr, vecmem::copy& copy,
cl::sycl::queue& queue) {

// Get the number of tracks.
const track_candidate_container_types::const_device::header_vector::
size_type n_tracks = copy.get_size(track_candidates_view.headers);

// Get the number of the track candidates (measurements) in each track.
const std::vector<track_candidate_container_types::const_device::
item_vector::value_type::size_type>
candidate_sizes = copy.get_sizes(track_candidates_view.items);

// Create the result buffer.
track_state_container_types::buffer track_states_buffer{
{n_tracks, mr.main},
{candidate_sizes, mr.main, mr.host,
vecmem::data::buffer_type::resizable}};
vecmem::copy::event_type track_states_headers_setup_event =
copy.setup(track_states_buffer.headers);
vecmem::copy::event_type track_states_items_setup_event =
copy.setup(track_states_buffer.items);

// Return early, if there are no tracks.
if (n_tracks == 0) {
track_states_headers_setup_event->wait();
track_states_items_setup_event->wait();
return track_states_buffer;
}

// Create the buffers for sorting the parameter IDs.
vecmem::data::vector_buffer<device::sort_key> keys_buffer(n_tracks,
mr.main);
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(n_tracks,
mr.main);
vecmem::copy::event_type keys_setup_event = copy.setup(keys_buffer);
vecmem::copy::event_type param_ids_setup_event =
copy.setup(param_ids_buffer);
keys_setup_event->wait();
param_ids_setup_event->wait();

// The execution range for the two kernels of the function.
static constexpr unsigned int localSize = 64;
cl::sycl::nd_range<1> range = calculate1DimNdRange(n_tracks, localSize);

// Fill the keys and param_ids buffers.
cl::sycl::event fill_keys_event = queue.submit([&](cl::sycl::handler& h) {
h.parallel_for<kernels::fill_sort_keys>(
range,
[track_candidates_view, keys_view = vecmem::get_data(keys_buffer),
param_ids_view = vecmem::get_data(param_ids_buffer)](
cl::sycl::nd_item<1> item) {
device::fill_sort_keys(item.get_global_linear_id(),
track_candidates_view, keys_view,
param_ids_view);
});
});

// Sort the key to get the sorted parameter ids
vecmem::device_vector<device::sort_key> keys_device(keys_buffer);
vecmem::device_vector<unsigned int> param_ids_device(param_ids_buffer);
fill_keys_event.wait_and_throw();
oneapi::dpl::sort_by_key(oneapi::dpl::execution::dpcpp_default,
keys_device.begin(), keys_device.end(),
param_ids_device.begin());

// Run the fitting, using the sorted parameter IDs.
track_state_container_types::view track_states_view = track_states_buffer;
track_states_headers_setup_event->wait();
track_states_items_setup_event->wait();
queue
.submit([&](cl::sycl::handler& h) {
h.parallel_for<fit_kernel_t>(
range, [det_view, field_view, config, track_candidates_view,
param_ids_view = vecmem::get_data(param_ids_buffer),
track_states_view](cl::sycl::nd_item<1> item) {
device::fit<fitter_t>(item.get_global_linear_id(), det_view,
field_view, config,
track_candidates_view, param_ids_view,
track_states_view);
});
})
.wait_and_throw();

// Return the fitted tracks.
return track_states_buffer;
}

} // namespace details
} // namespace traccc::sycl
Loading

0 comments on commit 1554db0

Please sign in to comment.