diff --git a/examples/run/common/throughput_mt.ipp b/examples/run/common/throughput_mt.ipp index c3dd3ed2b2..3b25ee36c7 100644 --- a/examples/run/common/throughput_mt.ipp +++ b/examples/run/common/throughput_mt.ipp @@ -14,17 +14,25 @@ #include "traccc/options/program_options.hpp" #include "traccc/options/threading.hpp" #include "traccc/options/throughput.hpp" +#include "traccc/options/track_finding.hpp" +#include "traccc/options/track_propagation.hpp" #include "traccc/options/track_seeding.hpp" // I/O include(s). #include "traccc/io/demonstrator_edm.hpp" #include "traccc/io/read.hpp" +#include "traccc/io/read_geometry.hpp" +#include "traccc/io/utils.hpp" // Performance measurement include(s). #include "traccc/performance/throughput.hpp" #include "traccc/performance/timer.hpp" #include "traccc/performance/timing_info.hpp" +// Detray include(s). +#include "detray/core/detector.hpp" +#include "detray/io/frontend/detector_reader.hpp" + // VecMem include(s). #include @@ -53,12 +61,14 @@ int throughput_mt(std::string_view description, int argc, char* argv[], opts::input_data input_opts; opts::clusterization clusterization_opts; opts::track_seeding seeding_opts; + opts::track_finding finding_opts; + opts::track_propagation propagation_opts; opts::throughput throughput_opts; opts::threading threading_opts; opts::program_options program_opts{ description, {detector_opts, input_opts, clusterization_opts, seeding_opts, - throughput_opts, threading_opts}, + finding_opts, propagation_opts, throughput_opts, threading_opts}, argc, argv}; @@ -75,6 +85,34 @@ int throughput_mt(std::string_view description, int argc, char* argv[], // Memory resource to use in the test. HOST_MR uncached_host_mr; + // Read in the geometry. + auto [surface_transforms, barcode_map] = traccc::io::read_geometry( + detector_opts.detector_file, + (detector_opts.use_detray_detector ? traccc::data_format::json + : traccc::data_format::csv)); + using detector_type = detray::detector; + detector_type detector{uncached_host_mr}; + if (detector_opts.use_detray_detector) { + // Set up the detector reader configuration. + detray::io::detector_reader_config cfg; + cfg.add_file(traccc::io::data_directory() + + detector_opts.detector_file); + if (detector_opts.material_file.empty() == false) { + cfg.add_file(traccc::io::data_directory() + + detector_opts.material_file); + } + if (detector_opts.grid_file.empty() == false) { + cfg.add_file(traccc::io::data_directory() + + detector_opts.grid_file); + } + + // Read the detector. + auto det = + detray::io::read_detector(uncached_host_mr, cfg); + detector = std::move(det.first); + } + // Read in all input events into memory. demonstrator_input input(&uncached_host_mr); @@ -85,9 +123,12 @@ int throughput_mt(std::string_view description, int argc, char* argv[], input.push_back(demonstrator_input::value_type(&uncached_host_mr)); } // Read event data into input vector - io::read(input, input_opts.events, input_opts.directory, - detector_opts.detector_file, detector_opts.digitization_file, - input_opts.format); + io::read( + input, input_opts.events, input_opts.directory, + detector_opts.detector_file, detector_opts.digitization_file, + input_opts.format, + (detector_opts.use_detray_detector ? traccc::data_format::json + : traccc::data_format::csv)); } // Set up cached memory resources on top of the host memory resource @@ -95,6 +136,25 @@ int throughput_mt(std::string_view description, int argc, char* argv[], std::vector > cached_host_mrs{threading_opts.threads + 1}; + // Algorithm configuration(s). + typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg; + finding_cfg.min_track_candidates_per_track = + finding_opts.track_candidates_range[0]; + finding_cfg.max_track_candidates_per_track = + finding_opts.track_candidates_range[1]; + finding_cfg.min_step_length_for_next_surface = + finding_opts.min_step_length_for_next_surface; + finding_cfg.max_step_counts_for_next_surface = + finding_opts.max_step_counts_for_next_surface; + finding_cfg.chi2_max = finding_opts.chi2_max; + finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed; + finding_cfg.max_num_skipping_per_cand = + finding_opts.max_num_skipping_per_cand; + propagation_opts.setup(finding_cfg.propagation); + + typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg; + propagation_opts.setup(fitting_cfg.propagation); + // Set up the full-chain algorithm(s). One for each thread. std::vector algs; algs.reserve(threading_opts.threads + 1); @@ -108,11 +168,15 @@ int throughput_mt(std::string_view description, int argc, char* argv[], ? static_cast( *(cached_host_mrs.at(i))) : static_cast(uncached_host_mr); - algs.push_back({alg_host_mr, - clusterization_opts.target_cells_per_partition, - seeding_opts.seedfinder, - {seeding_opts.seedfinder}, - seeding_opts.seedfilter}); + algs.push_back( + {alg_host_mr, + clusterization_opts.target_cells_per_partition, + seeding_opts.seedfinder, + {seeding_opts.seedfinder}, + seeding_opts.seedfilter, + finding_cfg, + fitting_cfg, + (detector_opts.use_detray_detector ? &detector : nullptr)}); } // Seed the random number generator. diff --git a/examples/run/common/throughput_st.ipp b/examples/run/common/throughput_st.ipp index 237907692a..5853db6c51 100644 --- a/examples/run/common/throughput_st.ipp +++ b/examples/run/common/throughput_st.ipp @@ -13,17 +13,25 @@ #include "traccc/options/input_data.hpp" #include "traccc/options/program_options.hpp" #include "traccc/options/throughput.hpp" +#include "traccc/options/track_finding.hpp" +#include "traccc/options/track_propagation.hpp" #include "traccc/options/track_seeding.hpp" // I/O include(s). #include "traccc/io/demonstrator_edm.hpp" #include "traccc/io/read.hpp" +#include "traccc/io/read_geometry.hpp" +#include "traccc/io/utils.hpp" // Performance measurement include(s). #include "traccc/performance/throughput.hpp" #include "traccc/performance/timer.hpp" #include "traccc/performance/timing_info.hpp" +// Detray include(s). +#include "detray/core/detector.hpp" +#include "detray/io/frontend/detector_reader.hpp" + // VecMem include(s). #include @@ -44,11 +52,13 @@ int throughput_st(std::string_view description, int argc, char* argv[], opts::input_data input_opts; opts::clusterization clusterization_opts; opts::track_seeding seeding_opts; + opts::track_finding finding_opts; + opts::track_propagation propagation_opts; opts::throughput throughput_opts; opts::program_options program_opts{ description, {detector_opts, input_opts, clusterization_opts, seeding_opts, - throughput_opts}, + finding_opts, propagation_opts, throughput_opts}, argc, argv}; @@ -60,6 +70,34 @@ int throughput_st(std::string_view description, int argc, char* argv[], std::unique_ptr cached_host_mr = std::make_unique(uncached_host_mr); + // Read in the geometry. + auto [surface_transforms, barcode_map] = traccc::io::read_geometry( + detector_opts.detector_file, + (detector_opts.use_detray_detector ? traccc::data_format::json + : traccc::data_format::csv)); + using detector_type = detray::detector; + detector_type detector{uncached_host_mr}; + if (detector_opts.use_detray_detector) { + // Set up the detector reader configuration. + detray::io::detector_reader_config cfg; + cfg.add_file(traccc::io::data_directory() + + detector_opts.detector_file); + if (detector_opts.material_file.empty() == false) { + cfg.add_file(traccc::io::data_directory() + + detector_opts.material_file); + } + if (detector_opts.grid_file.empty() == false) { + cfg.add_file(traccc::io::data_directory() + + detector_opts.grid_file); + } + + // Read the detector. + auto det = + detray::io::read_detector(uncached_host_mr, cfg); + detector = std::move(det.first); + } + vecmem::memory_resource& alg_host_mr = use_host_caching ? static_cast(*cached_host_mr) @@ -75,17 +113,40 @@ int throughput_st(std::string_view description, int argc, char* argv[], input.push_back(demonstrator_input::value_type(&uncached_host_mr)); } // Read event data into input vector - io::read(input, input_opts.events, input_opts.directory, - detector_opts.detector_file, detector_opts.digitization_file, - input_opts.format); + io::read( + input, input_opts.events, input_opts.directory, + detector_opts.detector_file, detector_opts.digitization_file, + input_opts.format, + (detector_opts.use_detray_detector ? traccc::data_format::json + : traccc::data_format::csv)); } + // Algorithm configuration(s). + typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg; + finding_cfg.min_track_candidates_per_track = + finding_opts.track_candidates_range[0]; + finding_cfg.max_track_candidates_per_track = + finding_opts.track_candidates_range[1]; + finding_cfg.min_step_length_for_next_surface = + finding_opts.min_step_length_for_next_surface; + finding_cfg.max_step_counts_for_next_surface = + finding_opts.max_step_counts_for_next_surface; + finding_cfg.chi2_max = finding_opts.chi2_max; + finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed; + finding_cfg.max_num_skipping_per_cand = + finding_opts.max_num_skipping_per_cand; + propagation_opts.setup(finding_cfg.propagation); + + typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg; + propagation_opts.setup(fitting_cfg.propagation); + // Set up the full-chain algorithm. std::unique_ptr alg = std::make_unique( alg_host_mr, clusterization_opts.target_cells_per_partition, seeding_opts.seedfinder, spacepoint_grid_config{seeding_opts.seedfinder}, - seeding_opts.seedfilter); + seeding_opts.seedfilter, finding_cfg, fitting_cfg, + (detector_opts.use_detray_detector ? &detector : nullptr)); // Seed the random number generator. std::srand(std::time(0)); diff --git a/examples/run/cpu/CMakeLists.txt b/examples/run/cpu/CMakeLists.txt index 6255513b88..5169a65ca2 100644 --- a/examples/run/cpu/CMakeLists.txt +++ b/examples/run/cpu/CMakeLists.txt @@ -33,12 +33,14 @@ add_library( traccc_examples_cpu STATIC "full_chain_algorithm.hpp" "full_chain_algorithm.cpp" ) target_link_libraries( traccc_examples_cpu - PUBLIC vecmem::core traccc::core ) + PUBLIC vecmem::core detray::core detray::utils traccc::core ) traccc_add_executable( throughput_st "throughput_st.cpp" - LINK_LIBRARIES vecmem::core traccc::core traccc::io - traccc::performance traccc::options traccc_examples_cpu ) + LINK_LIBRARIES vecmem::core detray::utils detray::io + traccc::core traccc::io traccc::performance + traccc::options traccc_examples_cpu ) traccc_add_executable( throughput_mt "throughput_mt.cpp" - LINK_LIBRARIES TBB::tbb vecmem::core traccc::core traccc::io - traccc::performance traccc::options traccc_examples_cpu ) + LINK_LIBRARIES TBB::tbb vecmem::core detray::utils detray::io + traccc::core traccc::io traccc::performance + traccc::options traccc_examples_cpu ) diff --git a/examples/run/cpu/full_chain_algorithm.cpp b/examples/run/cpu/full_chain_algorithm.cpp index 6e7c923a75..4149492536 100644 --- a/examples/run/cpu/full_chain_algorithm.cpp +++ b/examples/run/cpu/full_chain_algorithm.cpp @@ -14,26 +14,56 @@ full_chain_algorithm::full_chain_algorithm( vecmem::memory_resource& mr, unsigned int, const seedfinder_config& finder_config, const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config) - : m_clusterization(mr), + const seedfilter_config& filter_config, + const finding_algorithm::config_type& finding_config, + const fitting_algorithm::config_type& fitting_config, + detector_type* detector) + : m_field_vec{0.f, 0.f, finder_config.bFieldInZ}, + m_field(detray::bfield::create_const_field(m_field_vec)), + m_detector(detector), + m_clusterization(mr), m_spacepoint_formation(mr), m_seeding(finder_config, grid_config, filter_config, mr), m_track_parameter_estimation(mr), + m_finding(finding_config), + m_fitting(fitting_config), m_finder_config(finder_config), m_grid_config(grid_config), - m_filter_config(filter_config) {} + m_filter_config(filter_config), + m_finding_config(finding_config), + m_fitting_config(fitting_config) {} full_chain_algorithm::output_type full_chain_algorithm::operator()( const cell_collection_types::host& cells, const cell_module_collection_types::host& modules) const { + // Run the clusterization. + const host::clusterization_algorithm::output_type measurements = + m_clusterization(vecmem::get_data(cells), vecmem::get_data(modules)); + + // Run the seed-finding. const host::spacepoint_formation_algorithm::output_type spacepoints = - m_spacepoint_formation( - vecmem::get_data(m_clusterization(vecmem::get_data(cells), - vecmem::get_data(modules))), - vecmem::get_data(modules)); - return m_track_parameter_estimation(spacepoints, m_seeding(spacepoints), - {0.f, 0.f, m_finder_config.bFieldInZ}); + m_spacepoint_formation(vecmem::get_data(measurements), + vecmem::get_data(modules)); + const track_params_estimation::output_type track_params = + m_track_parameter_estimation(spacepoints, m_seeding(spacepoints), + m_field_vec); + + // If we have a Detray detector, run the track finding and fitting. + if (m_detector != nullptr) { + + // Return the final container, after track finding and fitting. + return m_fitting( + *m_detector, m_field, + m_finding(*m_detector, m_field, measurements, track_params)); + + } + // If not, just return an empty object. + else { + + // Return an empty object. + return {}; + } } } // namespace traccc diff --git a/examples/run/cpu/full_chain_algorithm.hpp b/examples/run/cpu/full_chain_algorithm.hpp index 8a7cf353f1..bf5b6f86c4 100644 --- a/examples/run/cpu/full_chain_algorithm.hpp +++ b/examples/run/cpu/full_chain_algorithm.hpp @@ -11,10 +11,21 @@ #include "traccc/clusterization/clusterization_algorithm.hpp" #include "traccc/clusterization/spacepoint_formation_algorithm.hpp" #include "traccc/edm/cell.hpp" +#include "traccc/edm/track_state.hpp" +#include "traccc/finding/finding_algorithm.hpp" +#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/seeding/seeding_algorithm.hpp" #include "traccc/seeding/track_params_estimation.hpp" #include "traccc/utils/algorithm.hpp" +// Detray include(s). +#include "detray/core/detector.hpp" +#include "detray/detectors/bfield.hpp" +#include "detray/navigation/navigator.hpp" +#include "detray/propagator/propagator.hpp" +#include "detray/propagator/rk_stepper.hpp" + // VecMem include(s). #include @@ -24,12 +35,35 @@ namespace traccc { /// /// At least as much as is implemented in the project at any given moment. /// -class full_chain_algorithm - : public algorithm { +class full_chain_algorithm : public algorithm { public: + /// @name Type declaration(s) + /// @{ + + /// Detector type used during track finding and fitting + using detector_type = detray::detector; + + /// Stepper type used by the track finding and fitting algorithms + using stepper_type = + detray::rk_stepper>; + /// Navigator type used by the track finding and fitting algorithms + using navigator_type = detray::navigator; + + /// Track finding algorithm type + using finding_algorithm = + traccc::finding_algorithm; + /// Track fitting algorithm type + using fitting_algorithm = traccc::fitting_algorithm< + traccc::kalman_fitter>; + + /// @} + /// Algorithm constructor /// /// @param mr The memory resource to use for the intermediate and result @@ -37,11 +71,13 @@ class full_chain_algorithm /// @param dummy This is not used anywhere. Allows templating CPU/Device /// algorithm. /// - full_chain_algorithm(vecmem::memory_resource& mr, unsigned int dummy, const seedfinder_config& finder_config, const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config); + const seedfilter_config& filter_config, + const finding_algorithm::config_type& finding_config, + const fitting_algorithm::config_type& fitting_config, + detector_type* detector); /// Reconstruct track parameters in the entire detector /// @@ -53,6 +89,14 @@ class full_chain_algorithm const cell_module_collection_types::host& modules) const override; private: + /// Constant B field for the (seed) track parameter estimation + traccc::vector3 m_field_vec; + /// Constant B field for the track finding and fitting + detray::bfield::const_field_t m_field; + + /// Detector + detector_type* m_detector; + /// @name Sub-algorithms used by this full-chain algorithm /// @{ @@ -65,11 +109,28 @@ class full_chain_algorithm /// Track parameter estimation algorithm track_params_estimation m_track_parameter_estimation; - /// Configs + /// Track finding algorithm + finding_algorithm m_finding; + /// Track fitting algorithm + fitting_algorithm m_fitting; + + /// @} + + /// @name Algorithm configurations + /// @{ + + /// Configuration for the seed finding seedfinder_config m_finder_config; + /// Configuration for the spacepoint grid formation spacepoint_grid_config m_grid_config; + /// Configuration for the seed filtering seedfilter_config m_filter_config; + /// Configuration for the track finding + finding_algorithm::config_type m_finding_config; + /// Configuration for the track fitting + fitting_algorithm::config_type m_fitting_config; + /// @} }; // class full_chain_algorithm diff --git a/examples/run/cuda/CMakeLists.txt b/examples/run/cuda/CMakeLists.txt index bd7c0e95a9..ee5bfa05f6 100644 --- a/examples/run/cuda/CMakeLists.txt +++ b/examples/run/cuda/CMakeLists.txt @@ -33,15 +33,17 @@ add_library( traccc_examples_cuda STATIC "full_chain_algorithm.hpp" "full_chain_algorithm.cpp" ) target_link_libraries( traccc_examples_cuda - PUBLIC CUDA::cudart vecmem::core vecmem::cuda traccc::core - traccc::device_common traccc::cuda ) + PUBLIC CUDA::cudart vecmem::core vecmem::cuda detray::core detray::utils + traccc::core traccc::device_common traccc::cuda ) traccc_add_executable( throughput_st_cuda "throughput_st.cpp" - LINK_LIBRARIES vecmem::core vecmem::cuda traccc::io traccc::performance + LINK_LIBRARIES vecmem::core vecmem::cuda detray::utils detray::io + traccc::io traccc::performance traccc::core traccc::device_common traccc::cuda traccc::options traccc_examples_cuda ) traccc_add_executable( throughput_mt_cuda "throughput_mt.cpp" - LINK_LIBRARIES TBB::tbb vecmem::core vecmem::cuda traccc::io traccc::performance + LINK_LIBRARIES TBB::tbb vecmem::core vecmem::cuda detray::utils detray::io + traccc::io traccc::performance traccc::core traccc::device_common traccc::cuda traccc::options traccc_examples_cuda ) diff --git a/examples/run/cuda/full_chain_algorithm.cpp b/examples/run/cuda/full_chain_algorithm.cpp index d43706eff6..7c0063f1a4 100644 --- a/examples/run/cuda/full_chain_algorithm.cpp +++ b/examples/run/cuda/full_chain_algorithm.cpp @@ -32,13 +32,19 @@ full_chain_algorithm::full_chain_algorithm( const unsigned short target_cells_per_partition, const seedfinder_config& finder_config, const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config) + const seedfilter_config& filter_config, + const finding_algorithm::config_type& finding_config, + const fitting_algorithm::config_type& fitting_config, + host_detector_type* detector) : m_host_mr(host_mr), m_stream(), m_device_mr(), m_cached_device_mr( std::make_unique(m_device_mr)), m_copy(m_stream.cudaStream()), + m_field_vec{0.f, 0.f, finder_config.bFieldInZ}, + m_field(detray::bfield::create_const_field(m_field_vec)), + m_detector(detector), m_target_cells_per_partition(target_cells_per_partition), m_clusterization(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, m_stream, m_target_cells_per_partition), @@ -50,9 +56,18 @@ full_chain_algorithm::full_chain_algorithm( m_stream), m_track_parameter_estimation( memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, m_stream), + m_finding(finding_config, + memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, + m_stream), + m_fitting(fitting_config, + memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, + m_stream), + m_result_copy(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy), m_finder_config(finder_config), m_grid_config(grid_config), - m_filter_config(filter_config) { + m_filter_config(filter_config), + m_finding_config(finding_config), + m_fitting_config(fitting_config) { // Tell the user what device is being used. int device = 0; @@ -62,6 +77,13 @@ full_chain_algorithm::full_chain_algorithm( std::cout << "Using CUDA device: " << props.name << " [id: " << device << ", bus: " << props.pciBusID << ", device: " << props.pciDeviceID << "]" << std::endl; + + // Copy the detector to the device. + if (m_detector != nullptr) { + m_device_detector = detray::get_buffer(detray::get_data(*m_detector), + m_device_mr, m_copy); + m_device_detector_view = detray::get_data(m_device_detector); + } } full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent) @@ -71,6 +93,9 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent) m_cached_device_mr( std::make_unique(m_device_mr)), m_copy(m_stream.cudaStream()), + m_field_vec(parent.m_field_vec), + m_field(parent.m_field), + m_detector(parent.m_detector), m_target_cells_per_partition(parent.m_target_cells_per_partition), m_clusterization(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, m_stream, m_target_cells_per_partition), @@ -82,9 +107,26 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent) memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, m_stream), m_track_parameter_estimation( memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, m_stream), + m_finding(parent.m_finding_config, + memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, + m_stream), + m_fitting(parent.m_fitting_config, + memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy, + m_stream), + m_result_copy(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy), m_finder_config(parent.m_finder_config), m_grid_config(parent.m_grid_config), - m_filter_config(parent.m_filter_config) {} + m_filter_config(parent.m_filter_config), + m_finding_config(parent.m_finding_config), + m_fitting_config(parent.m_fitting_config) { + + // Copy the detector to the device. + if (m_detector != nullptr) { + m_device_detector = detray::get_buffer(detray::get_data(*m_detector), + m_device_mr, m_copy); + m_device_detector_view = detray::get_data(m_device_detector); + } +} full_chain_algorithm::~full_chain_algorithm() { @@ -108,19 +150,51 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()( // Run the clusterization (asynchronously). const clusterization_algorithm::output_type measurements = m_clusterization(cells_buffer, modules_buffer); + m_measurement_sorting(measurements); + + // Run the seed-finding (asynchronously). const spacepoint_formation_algorithm::output_type spacepoints = - m_spacepoint_formation(m_measurement_sorting(measurements), - modules_buffer); + m_spacepoint_formation(measurements, modules_buffer); const track_params_estimation::output_type track_params = m_track_parameter_estimation(spacepoints, m_seeding(spacepoints), - {0.f, 0.f, m_finder_config.bFieldInZ}); - - // Get the final data back to the host. - bound_track_parameters_collection_types::host result(&m_host_mr); - m_copy(track_params, result)->wait(); - - // Return the host container. - return result; + m_field_vec); + + // If we have a Detray detector, run the track finding and fitting. + if (m_detector != nullptr) { + + // Create the buffer needed by track finding and fitting. + auto navigation_buffer = detray::create_candidates_buffer( + *m_detector, + m_finding_config.navigation_buffer_size_scaler * + m_copy.get_size(track_params), + *m_cached_device_mr, &m_host_mr); + + // Run the track finding (asynchronously). + const finding_algorithm::output_type track_candidates = + m_finding(m_device_detector_view, m_field, navigation_buffer, + measurements, track_params); + + // Run the track fitting (asynchronously). + const fitting_algorithm::output_type track_states = + m_fitting(m_device_detector_view, m_field, navigation_buffer, + track_candidates); + + // Return the final container, copied back to the host. + return m_result_copy(track_states); + + } + // If not, copy the track parameters back to the host, and return a dummy + // object. + else { + + // Copy the track parameters back to the host. + bound_track_parameters_collection_types::host track_params_host( + &m_host_mr); + m_copy(track_params, track_params_host)->wait(); + + // Return an empty object. + return {}; + } } } // namespace traccc::cuda diff --git a/examples/run/cuda/full_chain_algorithm.hpp b/examples/run/cuda/full_chain_algorithm.hpp index f5b6bba72b..c2c7b2e817 100644 --- a/examples/run/cuda/full_chain_algorithm.hpp +++ b/examples/run/cuda/full_chain_algorithm.hpp @@ -11,13 +11,24 @@ #include "traccc/cuda/clusterization/clusterization_algorithm.hpp" #include "traccc/cuda/clusterization/measurement_sorting_algorithm.hpp" #include "traccc/cuda/clusterization/spacepoint_formation_algorithm.hpp" +#include "traccc/cuda/finding/finding_algorithm.hpp" +#include "traccc/cuda/fitting/fitting_algorithm.hpp" #include "traccc/cuda/seeding/seeding_algorithm.hpp" #include "traccc/cuda/seeding/track_params_estimation.hpp" #include "traccc/cuda/utils/stream.hpp" -#include "traccc/device/container_h2d_copy_alg.hpp" +#include "traccc/device/container_d2h_copy_alg.hpp" #include "traccc/edm/cell.hpp" +#include "traccc/edm/track_state.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/utils/algorithm.hpp" +// Detray include(s). +#include "detray/core/detector.hpp" +#include "detray/detectors/bfield.hpp" +#include "detray/navigation/navigator.hpp" +#include "detray/propagator/propagator.hpp" +#include "detray/propagator/rk_stepper.hpp" + // VecMem include(s). #include #include @@ -33,12 +44,39 @@ namespace traccc::cuda { /// /// At least as much as is implemented in the project at any given moment. /// -class full_chain_algorithm - : public algorithm { +class full_chain_algorithm : public algorithm { public: + /// @name Type declaration(s) + /// @{ + + /// (Host) Detector type used during track finding and fitting + using host_detector_type = detray::detector; + /// (Device) Detector type used during track finding and fitting + using device_detector_type = + detray::detector; + + /// Stepper type used by the track finding and fitting algorithms + using stepper_type = + detray::rk_stepper>; + /// Navigator type used by the track finding and fitting algorithms + using navigator_type = detray::navigator; + + /// Track finding algorithm type + using finding_algorithm = + traccc::cuda::finding_algorithm; + /// Track fitting algorithm type + using fitting_algorithm = traccc::cuda::fitting_algorithm< + traccc::kalman_fitter>; + + /// @} + /// Algorithm constructor /// /// @param mr The memory resource to use for the intermediate and result @@ -50,7 +88,10 @@ class full_chain_algorithm const unsigned short target_cells_per_partiton, const seedfinder_config& finder_config, const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config); + const seedfilter_config& filter_config, + const finding_algorithm::config_type& finding_config, + const fitting_algorithm::config_type& fitting_config, + host_detector_type* detector); /// Copy constructor /// @@ -86,6 +127,18 @@ class full_chain_algorithm /// (Asynchronous) Memory copy object mutable vecmem::cuda::async_copy m_copy; + /// Constant B field for the (seed) track parameter estimation + traccc::vector3 m_field_vec; + /// Constant B field for the track finding and fitting + detray::bfield::const_field_t m_field; + + /// Host detector + host_detector_type* m_detector; + /// Buffer holding the detector's payload on the device + host_detector_type::buffer_type m_device_detector; + /// View of the detector's payload on the device + host_detector_type::view_type m_device_detector_view; + /// @name Sub-algorithms used by this full-chain algorithm /// @{ @@ -103,11 +156,31 @@ class full_chain_algorithm /// Track parameter estimation algorithm track_params_estimation m_track_parameter_estimation; - /// Configs + /// Track finding algorithm + finding_algorithm m_finding; + /// Track fitting algorithm + fitting_algorithm m_fitting; + + /// Algorithm copying the result container back to the host + device::container_d2h_copy_alg m_result_copy; + + /// @} + + /// @name Algorithm configurations + /// @{ + + /// Configuration for the seed finding seedfinder_config m_finder_config; + /// Configuration for the spacepoint grid formation spacepoint_grid_config m_grid_config; + /// Configuration for the seed filtering seedfilter_config m_filter_config; + /// Configuration for the track finding + finding_algorithm::config_type m_finding_config; + /// Configuration for the track fitting + fitting_algorithm::config_type m_fitting_config; + /// @} }; // class full_chain_algorithm diff --git a/examples/run/sycl/CMakeLists.txt b/examples/run/sycl/CMakeLists.txt index 273c04ed52..521283221f 100644 --- a/examples/run/sycl/CMakeLists.txt +++ b/examples/run/sycl/CMakeLists.txt @@ -1,6 +1,6 @@ # TRACCC library, part of the ACTS project (R&D line) # -# (c) 2021-2023 CERN for the benefit of the ACTS project +# (c) 2021-2024 CERN for the benefit of the ACTS project # # Mozilla Public License Version 2.0 @@ -35,14 +35,17 @@ add_library( traccc_examples_sycl OBJECT "full_chain_algorithm.hpp" "full_chain_algorithm.sycl" ) target_link_libraries( traccc_examples_sycl - PUBLIC vecmem::core vecmem::sycl traccc::core traccc::device_common traccc::sycl ) + PUBLIC vecmem::core vecmem::sycl detray::core detray::utils + traccc::core traccc::device_common traccc::sycl ) traccc_add_executable( throughput_st_sycl "throughput_st.cpp" - LINK_LIBRARIES vecmem::core vecmem::sycl traccc::io traccc::performance + LINK_LIBRARIES vecmem::core vecmem::sycl detray::utils detray::io + traccc::io traccc::performance traccc::core traccc::device_common traccc::sycl traccc::options traccc_examples_sycl ) traccc_add_executable( throughput_mt_sycl "throughput_mt.cpp" - LINK_LIBRARIES TBB::tbb vecmem::core vecmem::sycl traccc::io traccc::performance + LINK_LIBRARIES TBB::tbb vecmem::core vecmem::sycl detray::utils detray::io + traccc::io traccc::performance traccc::core traccc::device_common traccc::sycl traccc::options traccc_examples_sycl ) diff --git a/examples/run/sycl/full_chain_algorithm.hpp b/examples/run/sycl/full_chain_algorithm.hpp index 3c19ea4989..93ae91ea48 100644 --- a/examples/run/sycl/full_chain_algorithm.hpp +++ b/examples/run/sycl/full_chain_algorithm.hpp @@ -9,12 +9,22 @@ // Project include(s). #include "traccc/edm/cell.hpp" +#include "traccc/finding/finding_algorithm.hpp" +#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/sycl/clusterization/clusterization_algorithm.hpp" #include "traccc/sycl/clusterization/spacepoint_formation_algorithm.hpp" #include "traccc/sycl/seeding/seeding_algorithm.hpp" #include "traccc/sycl/seeding/track_params_estimation.hpp" #include "traccc/utils/algorithm.hpp" +// Detray include(s). +#include "detray/core/detector.hpp" +#include "detray/detectors/bfield.hpp" +#include "detray/navigation/navigator.hpp" +#include "detray/propagator/propagator.hpp" +#include "detray/propagator/rk_stepper.hpp" + // VecMem include(s). #include #include @@ -40,6 +50,30 @@ class full_chain_algorithm const cell_module_collection_types::host&)> { public: + /// @name (For now dummy...) Type declaration(s) + /// @{ + + /// Detector type used during track finding and fitting + using detector_type = detray::detector; + + /// Stepper type used by the track finding and fitting algorithms + using stepper_type = + detray::rk_stepper>; + /// Navigator type used by the track finding and fitting algorithms + using navigator_type = detray::navigator; + + /// Track finding algorithm type + using finding_algorithm = + traccc::finding_algorithm; + /// Track fitting algorithm type + using fitting_algorithm = traccc::fitting_algorithm< + traccc::kalman_fitter>; + + /// @} + /// Algorithm constructor /// /// @param mr The memory resource to use for the intermediate and result @@ -47,11 +81,15 @@ class full_chain_algorithm /// @param target_cells_per_partition The average number of cells in each /// partition. /// - full_chain_algorithm(vecmem::memory_resource& host_mr, - const unsigned short target_cells_per_partition, - const seedfinder_config& finder_config, - const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config); + full_chain_algorithm( + vecmem::memory_resource& host_mr, + const unsigned short target_cells_per_partition, + const seedfinder_config& finder_config, + const spacepoint_grid_config& grid_config, + const seedfilter_config& filter_config, + const finding_algorithm::config_type& finding_config = {}, + const fitting_algorithm::config_type& fitting_config = {}, + detector_type* detector = nullptr); /// Copy constructor /// diff --git a/examples/run/sycl/full_chain_algorithm.sycl b/examples/run/sycl/full_chain_algorithm.sycl index 66e50fa521..839ce39180 100644 --- a/examples/run/sycl/full_chain_algorithm.sycl +++ b/examples/run/sycl/full_chain_algorithm.sycl @@ -45,7 +45,9 @@ full_chain_algorithm::full_chain_algorithm( const unsigned short target_cells_per_partition, const seedfinder_config& finder_config, const spacepoint_grid_config& grid_config, - const seedfilter_config& filter_config) + const seedfilter_config& filter_config, + const finding_algorithm::config_type&, + const fitting_algorithm::config_type&, detector_type*) : m_data(new details::full_chain_algorithm_data{{::handle_async_error}}), m_host_mr(host_mr), m_device_mr(std::make_unique(