-
Notifications
You must be signed in to change notification settings - Fork 173
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: ML based seed filtering (#2709)
This pull request adds an ML-based seed selection that can be used to reduce the number of seeds after the seeding step. This PR is pending on the merging of #2690.
- Loading branch information
1 parent
66c671c
commit 61bc4e1
Showing
32 changed files
with
1,532 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
...s/Algorithms/TrackFindingML/include/ActsExamples/TrackFindingML/SeedFilterMLAlgorithm.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2023 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#pragma once | ||
|
||
#include "Acts/Plugins/Onnx/SeedClassifier.hpp" | ||
#include "ActsExamples/EventData/SimSeed.hpp" | ||
#include "ActsExamples/EventData/Track.hpp" | ||
#include "ActsExamples/Framework/DataHandle.hpp" | ||
#include "ActsExamples/TrackFindingML/AmbiguityResolutionML.hpp" | ||
|
||
#include <string> | ||
|
||
namespace ActsExamples { | ||
|
||
/// Removes seeds that seem to be duplicated and fake. | ||
/// | ||
/// The implementation works as follows: | ||
/// 1) Cluster together nearby seeds using a DBScan | ||
/// 2) For each seed use a neural network to compute a score | ||
/// 3) In each cluster keep the seed with the highest score | ||
class SeedFilterMLAlgorithm : public IAlgorithm { | ||
public: | ||
struct Config { | ||
/// Input estimated track parameters collection. | ||
std::string inputTrackParameters; | ||
/// Input seeds collection. | ||
std::string inputSimSeeds; | ||
/// Path to the ONNX model for the duplicate neural network | ||
std::string inputSeedFilterNN; | ||
/// Output estimated track parameters collection. | ||
std::string outputTrackParameters; | ||
/// Output seeds collection. | ||
std::string outputSimSeeds; | ||
/// Maximum distance between 2 tracks to be clustered in the DBScan | ||
float epsilonDBScan = 0.03; | ||
/// Minimum number of tracks to create a cluster in the DBScan | ||
int minPointsDBScan = 2; | ||
/// Minimum score a seed need to be selected | ||
float minSeedScore = 0.1; | ||
/// Clustering parameters weight for phi used before the DBSCAN | ||
double clusteringWeighPhi = 1.0; | ||
/// Clustering parameters weight for eta used before the DBSCAN | ||
double clusteringWeighEta = 1.0; | ||
/// Clustering parameters weight for z used before the DBSCAN | ||
double clusteringWeighZ = 50.0; | ||
/// Clustering parameters weight for pT used before the DBSCAN | ||
double clusteringWeighPt = 1.0; | ||
}; | ||
|
||
/// Construct the seed filter algorithm. | ||
/// | ||
/// @param cfg is the algorithm configuration | ||
/// @param lvl is the logging level | ||
SeedFilterMLAlgorithm(Config cfg, Acts::Logging::Level lvl); | ||
|
||
/// Run the seed filter algorithm. | ||
/// | ||
/// @param cxt is the algorithm context with event information | ||
/// @return a process code indication success or failure | ||
ProcessCode execute(const AlgorithmContext& ctx) const final; | ||
|
||
/// Const access to the config | ||
const Config& config() const { return m_cfg; } | ||
|
||
private: | ||
Config m_cfg; | ||
// ONNX model for track selection | ||
Acts::SeedClassifier m_seedClassifier; | ||
ReadDataHandle<TrackParametersContainer> m_inputTrackParameters{ | ||
this, "InputTrackParameters"}; | ||
ReadDataHandle<SimSeedContainer> m_inputSimSeeds{this, "InputSimSeeds"}; | ||
WriteDataHandle<TrackParametersContainer> m_outputTrackParameters{ | ||
this, "OutputTrackParameters"}; | ||
WriteDataHandle<SimSeedContainer> m_outputSimSeeds{this, "OutputSimSeeds"}; | ||
}; | ||
|
||
} // namespace ActsExamples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
Examples/Algorithms/TrackFindingML/src/SeedFilterMLAlgorithm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2023 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#include "ActsExamples/TrackFindingML/SeedFilterMLAlgorithm.hpp"

#include "Acts/Plugins/Mlpack/SeedFilterDBScanClustering.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Framework/WhiteBoard.hpp"

#include <cmath>
#include <cstddef>
#include <iterator>
#include <map>
#include <stdexcept>
#include <utility>
#include <vector>
|
||
/// Build the algorithm: store the configuration, load the seed classifier
/// from the configured ONNX model path and bind the data handles.
///
/// @param cfg is the algorithm configuration
/// @param lvl is the logging level
/// @throws std::invalid_argument if an input or output collection name is
///         empty
ActsExamples::SeedFilterMLAlgorithm::SeedFilterMLAlgorithm(
    Config cfg, Acts::Logging::Level lvl)
    : IAlgorithm("SeedFilterMLAlgorithm", lvl),
      m_cfg(std::move(cfg)),
      m_seedClassifier(m_cfg.inputSeedFilterNN.c_str()) {
  // Reject an unset collection name with the matching diagnostic
  const auto requireName = [](const std::string& collection,
                              const char* complaint) {
    if (collection.empty()) {
      throw std::invalid_argument(complaint);
    }
  };

  // All names are validated, in this order, before any handle is bound
  requireName(m_cfg.inputTrackParameters,
              "Missing track parameters input collection");
  requireName(m_cfg.inputSimSeeds, "Missing seed input collection");
  requireName(m_cfg.outputTrackParameters,
              "Missing track parameters output collection");
  requireName(m_cfg.outputSimSeeds, "Missing seed output collection");

  m_inputTrackParameters.initialize(m_cfg.inputTrackParameters);
  m_inputSimSeeds.initialize(m_cfg.inputSimSeeds);
  m_outputTrackParameters.initialize(m_cfg.outputTrackParameters);
  m_outputSimSeeds.initialize(m_cfg.outputSimSeeds);
}
|
||
/// Filter the seeds of one event.
///
/// Reads the seeds and their estimated track parameters, clusters the seeds
/// with a DBScan on weighted (phi, eta, z, pT), lets the seed classifier pick
/// the seeds to keep, and writes the kept seeds together with their track
/// parameters to the output collections.
///
/// @param ctx is the algorithm context with event information
/// @return ProcessCode::SUCCESS
/// @throws std::invalid_argument if the number of seeds differs from the
///         number of track parameters
ActsExamples::ProcessCode ActsExamples::SeedFilterMLAlgorithm::execute(
    const AlgorithmContext& ctx) const {
  // Read input data
  const auto& seeds = m_inputSimSeeds(ctx);
  const auto& params = m_inputTrackParameters(ctx);
  if (seeds.size() != params.size()) {
    throw std::invalid_argument(
        "The number of seeds and track parameters is different");
  }

  // Nothing to cluster or score: publish empty collections without calling
  // the clustering and the network on zero points
  if (seeds.empty()) {
    m_outputSimSeeds(ctx, SimSeedContainer{});
    m_outputTrackParameters(ctx, TrackParametersContainer{});
    return ActsExamples::ProcessCode::SUCCESS;
  }

  // One row of 14 features per seed for the NN
  Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
      networkInput(seeds.size(), 14);
  std::vector<std::vector<double>> clusteringParams;
  clusteringParams.reserve(seeds.size());

  // Loop over the seed and parameters to fill the input for the clustering
  // and the NN
  for (std::size_t i = 0; i < seeds.size(); i++) {
    const auto& bound = params[i].parameters();
    const auto& sp = seeds[i].sp();

    // Compute the track parameters
    const double theta = bound[Acts::eBoundTheta];
    const double pT =
        std::abs(1.0 / bound[Acts::eBoundQOverP]) * std::sin(theta);
    const double eta = std::atanh(std::cos(theta));
    const double phi = bound[Acts::eBoundPhi];

    // Fill and weight the clustering inputs (each quantity is divided by its
    // configured weight)
    clusteringParams.push_back(
        {phi / m_cfg.clusteringWeighPhi, eta / m_cfg.clusteringWeighEta,
         seeds[i].z() / m_cfg.clusteringWeighZ, pT / m_cfg.clusteringWeighPt});

    // Fill the NN input: track kinematics, the three space point positions,
    // the seed z and the seed quality
    networkInput.row(i) << pT, eta, phi, sp[0]->x(), sp[0]->y(), sp[0]->z(),
        sp[1]->x(), sp[1]->y(), sp[1]->z(), sp[2]->x(), sp[2]->y(),
        sp[2]->z(), seeds[i].z(), seeds[i].seedQuality();
  }

  // Cluster the seeds using DBscan
  auto cluster = Acts::dbscanSeedClustering(
      clusteringParams, m_cfg.epsilonDBScan, m_cfg.minPointsDBScan);

  // Select the indices of the seeds we want to keep
  const std::vector<std::size_t> goodSeed = m_seedClassifier.solveAmbiguity(
      cluster, networkInput, m_cfg.minSeedScore);

  // Create the output seed collection
  SimSeedContainer outputSeeds;
  outputSeeds.reserve(goodSeed.size());

  // Create the output track parameters collection
  TrackParametersContainer outputTrackParameters;
  outputTrackParameters.reserve(goodSeed.size());

  for (const auto i : goodSeed) {
    outputSeeds.push_back(seeds[i]);
    outputTrackParameters.push_back(params[i]);
  }

  // Hand the containers over by move instead of deep-copying them
  m_outputSimSeeds(ctx, std::move(outputSeeds));
  m_outputTrackParameters(ctx, std::move(outputTrackParameters));

  return ActsExamples::ProcessCode::SUCCESS;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.