Skip to content

Commit

Permalink
Merge pull request #626 from stephenswat/feat/cli_to_config
Browse files Browse the repository at this point in the history
Add CLI options to configuration type conversion
  • Loading branch information
stephenswat authored Jun 24, 2024
2 parents f22a970 + 711d554 commit 0ef502d
Show file tree
Hide file tree
Showing 22 changed files with 150 additions and 148 deletions.
17 changes: 10 additions & 7 deletions examples/options/include/traccc/options/clusterization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,29 @@
#pragma once

// Local include(s).
#include "traccc/options/details/config_provider.hpp"
#include "traccc/options/details/interface.hpp"

namespace traccc::opts {

/// Options for the cell clusterization algorithm(s)
class clusterization : public interface {
class clusterization : public interface,
public config_provider<unsigned short> {

public:
/// Constructor
clusterization();

/// Configuration conversion
operator unsigned short() const override;

private:
/// @name Options
/// @{

/// The number of cells to merge in a partition
unsigned short target_cells_per_partition = 1024;

/// @}

/// Constructor
clusterization();

private:
/// Print the specific options of this class
std::ostream& print_impl(std::ostream& out) const override;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

namespace traccc::opts {
/**
* @brief Mixin type to indicate that some set of program options can be
* converted to some configuration type.
*
* @tparam Config The config type to which this can be converted
*/
template <typename Config>
class config_provider {
public:
using config_type = Config;

virtual operator config_type() const = 0;
};
} // namespace traccc::opts
20 changes: 13 additions & 7 deletions examples/options/include/traccc/options/track_finding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#pragma once

// Project include(s).
#include "traccc/finding/finding_config.hpp"
#include "traccc/options/details/config_provider.hpp"
#include "traccc/options/details/interface.hpp"
#include "traccc/options/details/value_array.hpp"

Expand All @@ -20,12 +22,21 @@
namespace traccc::opts {

/// Configuration for track finding
class track_finding : public interface {
class track_finding : public interface,
public config_provider<finding_config<float>>,
public config_provider<finding_config<double>> {

public:
/// Constructor
track_finding();

/// Configuration conversion operators
operator finding_config<float>() const override;
operator finding_config<double>() const override;

private:
/// @name Options
/// @{

/// Number of track candidates per seed
opts::value_array<unsigned int, 2> track_candidates_range{3, 100};
/// Minimum step length that track should make to reach the next surface. It
Expand All @@ -40,13 +51,8 @@ class track_finding : public interface {
unsigned int nmax_per_seed = 10;
/// Maximum allowed number of skipped steps per candidate
unsigned int max_num_skipping_per_cand = 3;

/// @}

/// Constructor
track_finding();

private:
/// Print the specific options of this class
std::ostream& print_impl(std::ostream& out) const override;

Expand Down
21 changes: 12 additions & 9 deletions examples/options/include/traccc/options/track_propagation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#pragma once

// Local include(s).
#include "traccc/options/details/config_provider.hpp"
#include "traccc/options/details/interface.hpp"
#include "traccc/options/details/value_array.hpp"

Expand All @@ -17,17 +18,10 @@
namespace traccc::opts {

/// Command line options used in the propagation tests
class track_propagation : public interface {
class track_propagation : public interface,
public config_provider<detray::propagation::config> {

public:
/// @name Options
/// @{

/// Propagation configuration object
detray::propagation::config config;

/// @}

/// Constructor
track_propagation();

Expand All @@ -37,7 +31,16 @@ class track_propagation : public interface {
///
void read(const boost::program_options::variables_map& vm) override;

/// Configuration provider
operator detray::propagation::config() const override;

private:
/// @name Options
/// @{
/// Propagation configuration object
detray::propagation::config config;
/// @}

/// Print the specific options of this class
std::ostream& print_impl(std::ostream& out) const override;

Expand Down
4 changes: 4 additions & 0 deletions examples/options/src/clusterization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ clusterization::clusterization() : interface("Clusterization Options") {
"The number of cells to merge in a partition");
}

clusterization::operator unsigned short() const {
return target_cells_per_partition;
}

std::ostream& clusterization::print_impl(std::ostream& out) const {

out << " Target cells per partition: " << target_cells_per_partition;
Expand Down
24 changes: 24 additions & 0 deletions examples/options/src/track_finding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,30 @@ track_finding::track_finding() : interface("Track Finding Options") {
"Maximum allowed number of skipped steps per candidate");
}

track_finding::operator finding_config<float>() const {
finding_config<float> out;
out.min_track_candidates_per_track = track_candidates_range[0];
out.max_track_candidates_per_track = track_candidates_range[1];
out.min_step_length_for_next_surface = min_step_length_for_next_surface;
out.max_step_counts_for_next_surface = max_step_counts_for_next_surface;
out.chi2_max = chi2_max;
out.max_num_branches_per_seed = nmax_per_seed;
out.max_num_skipping_per_cand = max_num_skipping_per_cand;
return out;
}

track_finding::operator finding_config<double>() const {
finding_config<double> out;
out.min_track_candidates_per_track = track_candidates_range[0];
out.max_track_candidates_per_track = track_candidates_range[1];
out.min_step_length_for_next_surface = min_step_length_for_next_surface;
out.max_step_counts_for_next_surface = max_step_counts_for_next_surface;
out.chi2_max = chi2_max;
out.max_num_branches_per_seed = nmax_per_seed;
out.max_num_skipping_per_cand = max_num_skipping_per_cand;
return out;
}

std::ostream& track_finding::print_impl(std::ostream& out) const {

out << " Track candidates range : " << track_candidates_range << "\n"
Expand Down
4 changes: 4 additions & 0 deletions examples/options/src/track_propagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ void track_propagation::read(const po::variables_map&) {
config.navigation.search_window = m_search_window;
}

track_propagation::operator detray::propagation::config() const {
return config;
}

std::ostream& track_propagation::print_impl(std::ostream& out) const {

out << config;
Expand Down
23 changes: 6 additions & 17 deletions examples/run/common/throughput_mt.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,13 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
cached_host_mrs{threading_opts.threads + 1};

// Algorithm configuration(s).
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg;
finding_cfg.min_track_candidates_per_track =
finding_opts.track_candidates_range[0];
finding_cfg.max_track_candidates_per_track =
finding_opts.track_candidates_range[1];
finding_cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
finding_cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
finding_cfg.chi2_max = finding_opts.chi2_max;
finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
finding_cfg.max_num_skipping_per_cand =
finding_opts.max_num_skipping_per_cand;
finding_cfg.propagation = propagation_opts.config;
detray::propagation::config propagation_config(propagation_opts);
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg(
finding_opts);
finding_cfg.propagation = propagation_config;

typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg;
fitting_cfg.propagation = propagation_opts.config;
fitting_cfg.propagation = propagation_config;

// Set up the full-chain algorithm(s). One for each thread.
std::vector<FULL_CHAIN_ALG> algs;
Expand All @@ -170,7 +160,7 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
: static_cast<vecmem::memory_resource&>(uncached_host_mr);
algs.push_back(
{alg_host_mr,
clusterization_opts.target_cells_per_partition,
clusterization_opts,
seeding_opts.seedfinder,
{seeding_opts.seedfinder},
seeding_opts.seedfilter,
Expand Down Expand Up @@ -267,7 +257,6 @@ int throughput_mt(std::string_view description, int argc, char* argv[],
<< "," << threading_opts.threads << "," << input_opts.events
<< "," << throughput_opts.cold_run_events << ","
<< throughput_opts.processed_events << ","
<< clusterization_opts.target_cells_per_partition << ","
<< times.get_time("Warm-up processing").count() << ","
<< times.get_time("Event processing").count() << std::endl;
logFile.close();
Expand Down
23 changes: 6 additions & 17 deletions examples/run/common/throughput_st.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -122,28 +122,17 @@ int throughput_st(std::string_view description, int argc, char* argv[],
}

// Algorithm configuration(s).
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg;
finding_cfg.min_track_candidates_per_track =
finding_opts.track_candidates_range[0];
finding_cfg.max_track_candidates_per_track =
finding_opts.track_candidates_range[1];
finding_cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
finding_cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
finding_cfg.chi2_max = finding_opts.chi2_max;
finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
finding_cfg.max_num_skipping_per_cand =
finding_opts.max_num_skipping_per_cand;
finding_cfg.propagation = propagation_opts.config;
detray::propagation::config propagation_config(propagation_opts);
typename FULL_CHAIN_ALG::finding_algorithm::config_type finding_cfg(
finding_opts);
finding_cfg.propagation = propagation_config;

typename FULL_CHAIN_ALG::fitting_algorithm::config_type fitting_cfg;
fitting_cfg.propagation = propagation_opts.config;
fitting_cfg.propagation = propagation_config;

// Set up the full-chain algorithm.
std::unique_ptr<FULL_CHAIN_ALG> alg = std::make_unique<FULL_CHAIN_ALG>(
alg_host_mr, clusterization_opts.target_cells_per_partition,
seeding_opts.seedfinder,
alg_host_mr, clusterization_opts, seeding_opts.seedfinder,
spacepoint_grid_config{seeding_opts.seedfinder},
seeding_opts.seedfilter, finding_cfg, fitting_cfg,
(detector_opts.use_detray_detector ? &detector : nullptr));
Expand Down
22 changes: 8 additions & 14 deletions examples/run/cpu/seeding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,21 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
seeding_opts.seedfilter, host_mr);
traccc::track_params_estimation tp(host_mr);

// Propagation configuration
detray::propagation::config propagation_config(propagation_opts);

// Finding algorithm configuration
typename traccc::finding_algorithm<rk_stepper_type,
host_navigator_type>::config_type cfg;

cfg.min_track_candidates_per_track = finding_opts.track_candidates_range[0];
cfg.max_track_candidates_per_track = finding_opts.track_candidates_range[1];
cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
cfg.chi2_max = finding_opts.chi2_max;
cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
cfg.max_num_skipping_per_cand = finding_opts.max_num_skipping_per_cand;
cfg.propagation = propagation_opts.config;
typename traccc::finding_algorithm<
rk_stepper_type, host_navigator_type>::config_type cfg(finding_opts);

cfg.propagation = propagation_config;

traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
host_finding(cfg);

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.propagation = propagation_opts.config;
fit_cfg.propagation = propagation_config;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

Expand Down
20 changes: 5 additions & 15 deletions examples/run/cpu/seq_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,23 +131,13 @@ int seq_run(const traccc::opts::input_data& input_opts,
detray::bfield::create_const_field(field_vec);

// Algorithm configuration(s).
finding_algorithm::config_type finding_cfg;
finding_cfg.min_track_candidates_per_track =
finding_opts.track_candidates_range[0];
finding_cfg.max_track_candidates_per_track =
finding_opts.track_candidates_range[1];
finding_cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
finding_cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
finding_cfg.chi2_max = finding_opts.chi2_max;
finding_cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
finding_cfg.max_num_skipping_per_cand =
finding_opts.max_num_skipping_per_cand;
finding_cfg.propagation = propagation_opts.config;
detray::propagation::config propagation_config(propagation_opts);

finding_algorithm::config_type finding_cfg(finding_opts);
finding_cfg.propagation = propagation_config;

fitting_algorithm::config_type fitting_cfg;
fitting_cfg.propagation = propagation_opts.config;
fitting_cfg.propagation = propagation_config;

// Algorithms
traccc::host::clusterization_algorithm ca(host_mr);
Expand Down
20 changes: 7 additions & 13 deletions examples/run/cpu/truth_finding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,21 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
1e-4 / detray::unit<traccc::scalar>::GeV,
1e-4 * detray::unit<traccc::scalar>::ns};

// Propagation configuration
detray::propagation::config propagation_config(propagation_opts);

// Finding algorithm configuration
typename traccc::finding_algorithm<rk_stepper_type,
host_navigator_type>::config_type cfg;
cfg.min_track_candidates_per_track = finding_opts.track_candidates_range[0];
cfg.max_track_candidates_per_track = finding_opts.track_candidates_range[1];
cfg.min_step_length_for_next_surface =
finding_opts.min_step_length_for_next_surface;
cfg.max_step_counts_for_next_surface =
finding_opts.max_step_counts_for_next_surface;
cfg.chi2_max = finding_opts.chi2_max;
cfg.max_num_branches_per_seed = finding_opts.nmax_per_seed;
cfg.max_num_skipping_per_cand = finding_opts.max_num_skipping_per_cand;
cfg.propagation = propagation_opts.config;
typename traccc::finding_algorithm<
rk_stepper_type, host_navigator_type>::config_type cfg(finding_opts);
cfg.propagation = propagation_config;

// Finding algorithm object
traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
host_finding(cfg);

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.propagation = propagation_opts.config;
fit_cfg.propagation = propagation_config;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

Expand Down
2 changes: 1 addition & 1 deletion examples/run/cpu/truth_fitting_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ int main(int argc, char* argv[]) {

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.propagation = propagation_opts.config;
fit_cfg.propagation = propagation_opts;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

Expand Down
Loading

0 comments on commit 0ef502d

Please sign in to comment.