Skip to content

Commit

Permalink
Split the CUDA CKF into different TUs
Browse files Browse the repository at this point in the history
This commit splits the monstrously large CUDA track finding translation
unit up into smaller ones, one for each of the kernels. This should
speed up compilation times and decrease memory usage.

Also groups the payloads for each of the functions into convenient
structs, so we don't need to pass 20+ arguments for some of the kernel
calls.

Does not change the functionality of the code.
  • Loading branch information
stephenswat committed Oct 16, 2024
1 parent 4c26db6 commit bbeb4f7
Show file tree
Hide file tree
Showing 39 changed files with 826 additions and 470 deletions.
2 changes: 1 addition & 1 deletion cmake/traccc-compiler-options-cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ traccc_add_flag( CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr" )

# Make CUDA generate debug symbols for the device code as well in a debug
# build.
traccc_add_flag( CMAKE_CUDA_FLAGS_DEBUG "-G --keep -src-in-ptx" )
traccc_add_flag( CMAKE_CUDA_FLAGS_DEBUG "-G -src-in-ptx" )

# Ensure that line information is embedded in debugging builds so that
# profilers have access to line data.
Expand Down
3 changes: 2 additions & 1 deletion core/include/traccc/finding/ckf_aborter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "detray/definitions/detail/qualifiers.hpp"
#include "detray/propagator/base_actor.hpp"
#include "detray/propagator/base_stepper.hpp"
#include "traccc/definitions/primitives.hpp"

// System include(s)
#include <limits>
Expand Down Expand Up @@ -51,4 +52,4 @@ struct ckf_aborter : detray::actor {
}
};

} // namespace traccc
} // namespace traccc
24 changes: 13 additions & 11 deletions device/common/include/traccc/finding/device/apply_interaction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,31 @@
#pragma once

// Project include(s).
#include "detray/navigation/navigator.hpp"
#include "detray/propagator/actors/pointwise_material_interactor.hpp"
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/finding/finding_config.hpp"
#include "traccc/utils/particle.hpp"

namespace traccc::device {
/// Payload bundling the arguments of @c apply_interaction, so that the
/// function takes a single struct rather than a long parameter list.
template <typename detector_t>
struct apply_interaction_payload {
/// Detector view object; a @c detector_t is constructed from it
typename detector_t::view_type det_data;
/// The number of parameters (or tracks); threads whose index is not below
/// this value return without doing any work
/// NOTE(review): this is a signed int that is compared against a
/// std::size_t thread index in apply_interaction -- confirm it can never
/// be negative
const int n_params;
/// Collection of output bound track parameters
bound_track_parameters_collection_types::view params_view;
/// Vector of parameter liveness indicators (read-only)
vecmem::data::vector_view<const unsigned int> params_liveness_view;
};

/// Function applying the Pre material interaction to tracks spawned by bound
/// track parameters
///
/// @param[in] globalIndex The index of the current thread
/// @param[in] cfg Track finding config object
/// @param[in] det_data Detector view object
/// @param[in] n_params The number of parameters (or tracks)
/// @param[out] params_view Collection of output bound track_parameters
/// @param[in] params_liveness_view Vector of parameter liveness indicators
///
/// @param[inout] payload The function call payload
template <typename detector_t>
TRACCC_DEVICE inline void apply_interaction(
std::size_t globalIndex, const finding_config& cfg,
typename detector_t::view_type det_data, const int n_params,
bound_track_parameters_collection_types::view params_view,
vecmem::data::vector_view<const unsigned int> params_liveness_view);

const apply_interaction_payload<detector_t>& payload);
} // namespace traccc::device

// Include the implementation.
#include "traccc/finding/device/impl/apply_interaction.ipp"
#include "./impl/apply_interaction.ipp"
40 changes: 19 additions & 21 deletions device/common/include/traccc/finding/device/build_tracks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,22 @@

// Project include(s).
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/edm/measurement.hpp"
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_parameters.hpp"
#include "traccc/finding/candidate_link.hpp"

namespace traccc::device {
/// Payload bundling the arguments of @c build_tracks, so that the function
/// takes a single struct rather than a long parameter list.
struct build_tracks_payload {
/// [in] Measurements container view
measurement_collection_types::const_view measurements_view;
/// [in] Seed container view
bound_track_parameters_collection_types::const_view seeds_view;
/// [in] Link container view (jagged)
vecmem::data::jagged_vector_view<const candidate_link> links_view;
/// [in] Tip link container view; threads whose index is not below its size
/// return without doing any work
vecmem::data::vector_view<const typename candidate_link::link_index_type>
tips_view;
/// [out] Track candidate container view
track_candidate_container_types::view track_candidates_view;
/// [out] Valid indices meeting criteria; written at the position obtained
/// from the atomic increment of @c n_valid_tracks
vecmem::data::vector_view<unsigned int> valid_indices_view;
/// [out] The number of valid tracks meeting criteria. Held by reference:
/// the payload does not own the counter, which build_tracks increments
/// atomically through vecmem::device_atomic_ref.
unsigned int& n_valid_tracks;
};

/// Function for building full tracks from the link container:
/// The full tracks are built using the link container and tip link container.
Expand All @@ -19,28 +33,12 @@ namespace traccc::device {
///
/// @param[in] globalIndex The index of the current thread
/// @param[in] cfg Track finding config object
/// @param[in] measurements_view Measurements container view
/// @param[in] seeds_view Seed container view
/// @param[in] link_view Link container view
/// @param[in] param_to_link_view Container for param index -> link index
/// @param[in] tips_view Tip link container view
/// @param[out] track_candidates_view Track candidate container view
/// @param[out] valid_indices_view Valid indices meeting criteria
/// @param[out] n_valid_tracks The number of valid tracks meeting criteria

/// @param[inout] payload The function call payload
template <typename config_t>
TRACCC_DEVICE inline void build_tracks(
std::size_t globalIndex, const config_t cfg,
measurement_collection_types::const_view measurements_view,
bound_track_parameters_collection_types::const_view seeds_view,
vecmem::data::jagged_vector_view<const candidate_link> links_view,
vecmem::data::vector_view<const typename candidate_link::link_index_type>
tips_view,
track_candidate_container_types::view track_candidates_view,
vecmem::data::vector_view<unsigned int> valid_indices_view,
unsigned int& n_valid_tracks);
TRACCC_DEVICE inline void build_tracks(std::size_t globalIndex,
const config_t cfg,
const build_tracks_payload& payload);

} // namespace traccc::device

// Include the implementation.
#include "traccc/finding/device/impl/build_tracks.ipp"
#include "./impl/build_tracks.ipp"
19 changes: 8 additions & 11 deletions device/common/include/traccc/finding/device/fill_sort_keys.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,18 @@
#include "traccc/edm/track_candidate.hpp"

namespace traccc::device {
/// Payload bundling the arguments of @c fill_sort_keys.
struct fill_sort_keys_payload {
/// [in] The input parameters
bound_track_parameters_collection_types::const_view params_view;
/// [out] The key values
vecmem::data::vector_view<device::sort_key> keys_view;
/// [out] The param ids
vecmem::data::vector_view<unsigned int> ids_view;
};

/// Function used to fill the sort key container
///
/// @param[in] globalIndex The index of the current thread
/// @param[in] params_view The input parameters
/// @param[out] keys_view The key values
/// @param[out] ids_view The param ids
///
/// @param[inout] payload The function call payload
TRACCC_HOST_DEVICE inline void fill_sort_keys(
std::size_t globalIndex,
bound_track_parameters_collection_types::const_view params_view,
vecmem::data::vector_view<device::sort_key> keys_view,
vecmem::data::vector_view<unsigned int> ids_view);

std::size_t globalIndex, const fill_sort_keys_payload& payload);
} // namespace traccc::device

// Include the implementation.
#include "traccc/finding/device/impl/fill_sort_keys.ipp"
#include "./impl/fill_sort_keys.ipp"
70 changes: 32 additions & 38 deletions device/common/include/traccc/finding/device/find_tracks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,38 @@
#include "traccc/device/concepts/thread_id.hpp"
#include "traccc/edm/measurement.hpp"
#include "traccc/edm/track_parameters.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/finding/candidate_link.hpp"
#include "traccc/finding/finding_config.hpp"
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"

// Thrust include(s)
#include <thrust/binary_search.h>

namespace traccc::device {
/// Global memory payload bundling the arguments of @c find_tracks, so that
/// the function takes a single struct rather than a long parameter list.
template <typename detector_t>
struct find_tracks_payload {
/// [in] Detector view object
typename detector_t::view_type det_data;
/// [in] Measurements container view
measurement_collection_types::const_view measurements_view;
/// [in] Input parameters
bound_track_parameters_collection_types::const_view in_params_view;
/// [in] NOTE(review): presumably liveness indicators for the input
/// parameters, as for apply_interaction -- confirm against find_tracks.ipp
vecmem::data::vector_view<const unsigned int> in_params_liveness_view;
/// [in] The number of input params
const unsigned int n_in_params;
/// [in] View of a measurement -> barcode map
vecmem::data::vector_view<const detray::geometry::barcode> barcodes_view;
/// [in] Upper bounds of measurements unique w.r.t barcode
vecmem::data::vector_view<const unsigned int> upper_bounds_view;
/// [in] Link container from the previous step
vecmem::data::vector_view<const candidate_link> prev_links_view;
/// [in] Step index
const unsigned int step;
/// [in] Number of maximum candidates
const unsigned int n_max_candidates;
/// [out] Output parameters
bound_track_parameters_collection_types::view out_params_view;
/// [out] NOTE(review): presumably liveness indicators for the output
/// parameters -- confirm against find_tracks.ipp
vecmem::data::vector_view<unsigned int> out_params_liveness_view;
/// [out] Link container for the current step
vecmem::data::vector_view<candidate_link> links_view;
/// [out] Pointer to the number of total candidates for the current step;
/// not owned by the payload
unsigned int* n_total_candidates;
};

/// Shared memory payload of @c find_tracks. All members refer to scratch
/// space that the payload does not own.
struct find_tracks_shared_payload {
/// Shared memory scratch space
unsigned int* shared_num_candidates;
/// Shared memory scratch space
std::pair<unsigned int, unsigned int>* shared_candidates;
/// Shared memory scratch space; held by reference, so the referenced
/// object must outlive the payload
unsigned int& shared_candidates_size;
};

/// Function for combinatorial finding.
/// If the chi2 of the measurement < chi2_max, its measurement index and the
Expand All @@ -27,47 +54,14 @@ namespace traccc::device {
/// @param[in] thread_id A thread identifier object
/// @param[in] barrier A block-wide barrier
/// @param[in] cfg Track finding config object
/// @param[in] det_data Detector view object
/// @param[in] measurements_view Measurements container view
/// @param[in] in_params_view Input parameters
/// @param[in] n_in_params The number of input params
/// @param[in] barcodes_view View of a measurement -> barcode map
/// @param[in] upper_bounds_view Upper bounds of measurements unique w.r.t
/// barcode
/// @param[in] prev_links_view link container from the previous step
/// @param[in] prev_param_to_link_view param_to_link container from the
/// previous step
/// @param[in] step Step index
/// @param[in] n_max_candidates Number of maximum candidates
/// @param[out] out_params_view Output parameters
/// @param[out] links_view link container for the current step
/// @param[out] n_total_candidates The number of total candidates for the
/// current step
/// @param shared_num_candidates Shared memory scratch space
/// @param shared_candidates Shared memory scratch space
/// @param shared_candidates_size Shared memory scratch space
///
/// @param[inout] payload The global memory payload
/// @param[inout] shared_payload The shared memory payload
template <concepts::thread_id1 thread_id_t, concepts::barrier barrier_t,
typename detector_t, typename config_t>
TRACCC_DEVICE inline void find_tracks(
thread_id_t& thread_id, barrier_t& barrier, const config_t cfg,
typename detector_t::view_type det_data,
measurement_collection_types::const_view measurements_view,
bound_track_parameters_collection_types::const_view in_params_view,
vecmem::data::vector_view<const unsigned int> in_params_liveness_view,
const unsigned int n_in_params,
vecmem::data::vector_view<const detray::geometry::barcode> barcodes_view,
vecmem::data::vector_view<const unsigned int> upper_bounds_view,
vecmem::data::vector_view<const candidate_link> prev_links_view,
const unsigned int step, const unsigned int& n_max_candidates,
bound_track_parameters_collection_types::view out_params_view,
vecmem::data::vector_view<unsigned int> out_params_liveness_view,
vecmem::data::vector_view<candidate_link> links_view,
unsigned int& n_total_candidates, unsigned int* shared_num_candidates,
std::pair<unsigned int, unsigned int>* shared_candidates,
unsigned int& shared_candidates_size);

const find_tracks_payload<detector_t>& payload,
const find_tracks_shared_payload& shared_payload);
} // namespace traccc::device

// Include the implementation.
#include "traccc/finding/device/impl/find_tracks.ipp"
#include "./impl/find_tracks.ipp"
Original file line number Diff line number Diff line change
@@ -1,42 +1,39 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2023-2024 CERN for the benefit of the ACTS project
* (c) 2023 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Project include(s).
#include "traccc/definitions/math.hpp"
#include "detray/navigation/navigator.hpp"
#include "detray/propagator/actors/pointwise_material_interactor.hpp"
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/finding/finding_config.hpp"
#include "traccc/utils/particle.hpp"

// Detray include(s).
#include "detray/geometry/tracking_surface.hpp"
#include "vecmem/containers/device_vector.hpp"

namespace traccc::device {

template <typename detector_t>
TRACCC_DEVICE inline void apply_interaction(
std::size_t globalIndex, const finding_config& cfg,
typename detector_t::view_type det_data, const int n_params,
bound_track_parameters_collection_types::view params_view,
vecmem::data::vector_view<const unsigned int> params_liveness_view) {
const apply_interaction_payload<detector_t>& payload) {

// Type definitions
using algebra_type = typename detector_t::algebra_type;
using interactor_type = detray::pointwise_material_interactor<algebra_type>;

// Detector
detector_t det(det_data);
detector_t det(payload.det_data);

// in param
bound_track_parameters_collection_types::device params(params_view);
bound_track_parameters_collection_types::device params(payload.params_view);
vecmem::device_vector<const unsigned int> params_liveness(
params_liveness_view);
payload.params_liveness_view);

if (globalIndex >= n_params) {
if (globalIndex >= payload.n_params) {
return;
}

Expand All @@ -57,5 +54,4 @@ TRACCC_DEVICE inline void apply_interaction(
static_cast<int>(detray::navigation::direction::e_forward), sf);
}
}

} // namespace traccc::device
38 changes: 21 additions & 17 deletions device/common/include/traccc/finding/device/impl/build_tracks.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,37 @@

#pragma once

// Project include(s).
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/edm/measurement.hpp"
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_parameters.hpp"
#include "traccc/finding/candidate_link.hpp"

namespace traccc::device {

template <typename config_t>
TRACCC_DEVICE inline void build_tracks(
std::size_t globalIndex, const config_t cfg,
measurement_collection_types::const_view measurements_view,
bound_track_parameters_collection_types::const_view seeds_view,
vecmem::data::jagged_vector_view<const candidate_link> links_view,
vecmem::data::vector_view<const typename candidate_link::link_index_type>
tips_view,
track_candidate_container_types::view track_candidates_view,
vecmem::data::vector_view<unsigned int> valid_indices_view,
unsigned int& n_valid_tracks) {
TRACCC_DEVICE inline void build_tracks(std::size_t globalIndex,
const config_t cfg,
const build_tracks_payload& payload) {

measurement_collection_types::const_device measurements(measurements_view);
measurement_collection_types::const_device measurements(
payload.measurements_view);

bound_track_parameters_collection_types::const_device seeds(seeds_view);
bound_track_parameters_collection_types::const_device seeds(
payload.seeds_view);

vecmem::jagged_device_vector<const candidate_link> links(links_view);
vecmem::jagged_device_vector<const candidate_link> links(
payload.links_view);

vecmem::device_vector<const typename candidate_link::link_index_type> tips(
tips_view);
payload.tips_view);

track_candidate_container_types::device track_candidates(
track_candidates_view);
payload.track_candidates_view);

vecmem::device_vector<unsigned int> valid_indices(valid_indices_view);
vecmem::device_vector<unsigned int> valid_indices(
payload.valid_indices_view);

if (globalIndex >= tips.size()) {
return;
Expand Down Expand Up @@ -107,7 +111,7 @@ TRACCC_DEVICE inline void build_tracks(
n_cands <= cfg.max_track_candidates_per_track) {

vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
n_valid_tracks);
payload.n_valid_tracks);

const unsigned int pos = num_valid_tracks.fetch_add(1);
valid_indices[pos] = globalIndex;
Expand Down
Loading

0 comments on commit bbeb4f7

Please sign in to comment.