Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ML based seed filtering #2709

Merged
merged 43 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
17195aa
seed writting
Corentin-Allaire Nov 17, 2023
79e1b66
format
Corentin-Allaire Nov 17, 2023
8da4721
comments
Corentin-Allaire Nov 17, 2023
60acd9a
cmake
Corentin-Allaire Nov 17, 2023
09e9d4a
Merge remote-tracking branch 'upstream/main' into SeedWriting
Corentin-Allaire Nov 17, 2023
b38cc70
format
Corentin-Allaire Nov 17, 2023
e832ee2
size_t
Corentin-Allaire Nov 17, 2023
b0224dd
Apply suggestions from code review
Corentin-Allaire Nov 20, 2023
2a5d15b
SeedInfo
Corentin-Allaire Nov 20, 2023
2506889
fixed sized vector
Corentin-Allaire Nov 21, 2023
3d10132
Merge remote-tracking branch 'upstream/main' into SeedWriting
Corentin-Allaire Nov 21, 2023
75fe6f6
Merge remote-tracking branch 'origin/SeedWriting' into SeedWriting
Corentin-Allaire Nov 21, 2023
b2a9cf3
new files
Corentin-Allaire Nov 21, 2023
1e6a93c
update CMakelists
Corentin-Allaire Nov 21, 2023
acedcc3
algorithms
Corentin-Allaire Nov 21, 2023
238d49e
plugin onnx
Corentin-Allaire Nov 21, 2023
9f1116a
python files
Corentin-Allaire Nov 21, 2023
9d6979b
python bindings
Corentin-Allaire Nov 21, 2023
7de51a7
Merge branch 'SeedWriting' into MLSeedFilter
Corentin-Allaire Nov 21, 2023
8038cf0
format
Corentin-Allaire Nov 21, 2023
fca347a
Merge remote-tracking branch 'upstream/main' into MLSeedFilter
Corentin-Allaire Nov 21, 2023
f992d89
onnx file
Corentin-Allaire Nov 22, 2023
aa63d0d
bugfix
Corentin-Allaire Nov 24, 2023
df635fb
Merge remote-tracking branch 'upstream/main' into MLSeedFilter
Corentin-Allaire Nov 24, 2023
4670b5c
update initialised value
Corentin-Allaire Dec 5, 2023
27e1365
conflict
Corentin-Allaire Dec 5, 2023
2801ce2
working
Corentin-Allaire Dec 6, 2023
397cc2e
Merge remote-tracking branch 'upstream/main' into MLSeedFilter
Corentin-Allaire Dec 6, 2023
791498d
Greatly reduce trainning mem consuption
Corentin-Allaire Dec 7, 2023
10dd25c
Merge remote-tracking branch 'upstream/main' into MLSeedFilter
Corentin-Allaire Dec 7, 2023
b2926a0
remove useless part of matching
Corentin-Allaire Dec 7, 2023
fe3addb
improved the network
Corentin-Allaire Dec 11, 2023
d9a363e
doc
Corentin-Allaire Dec 11, 2023
30928ad
Merge remote-tracking branch 'upstream/main' into MLSeedFilter
Corentin-Allaire Dec 11, 2023
74f5a94
PR comments
Corentin-Allaire Dec 14, 2023
89b0d7c
Merge remote-tracking branch 'upstream/main' into MLSeedFilter
Corentin-Allaire Dec 14, 2023
fa1abc5
Apply suggestions from code review
Corentin-Allaire Dec 14, 2023
c70cce9
Apply suggestions from code review
Corentin-Allaire Dec 14, 2023
52a2d2a
fix doc
Corentin-Allaire Dec 14, 2023
a962269
Merge remote-tracking branch 'origin/MLSeedFilter' into MLSeedFilter
Corentin-Allaire Dec 14, 2023
3018e47
format
Corentin-Allaire Dec 14, 2023
d94a63f
remove data=data
Corentin-Allaire Dec 14, 2023
9297395
remove print
Corentin-Allaire Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ namespace detail {
///
/// @param trackMap : Multimap storing pair of track ID and vector of measurement ID. The keys are the number of measurement and are just there to facilitate the ordering.
/// @return an unordered map representing the clusters: the keys are the IDs of the primary track of each cluster and the values are vectors of the track IDs in that cluster.
std::unordered_map<int, std::vector<int>> clusterDuplicateTracks(
const std::multimap<int, std::pair<int, std::vector<int>>>& trackMap);
std::unordered_map<std::size_t, std::vector<std::size_t>>
clusterDuplicateTracks(
const std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>&
trackMap);

} // namespace detail
} // namespace Acts
14 changes: 8 additions & 6 deletions Core/src/TrackFinding/AmbiguityTrackClustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@

#include <iterator>

std::unordered_map<int, std::vector<int>> Acts::detail::clusterDuplicateTracks(
const std::multimap<int, std::pair<int, std::vector<int>>>& trackMap) {
std::unordered_map<std::size_t, std::vector<std::size_t>>
Acts::detail::clusterDuplicateTracks(
const std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>&
trackMap) {
// Unordered map associating a vector with all the track ID of a cluster to
// the ID of the first track of the cluster
std::unordered_map<int, std::vector<int>> cluster;
std::unordered_map<std::size_t, std::vector<std::size_t>> cluster;
// Unordered map associating hits to the ID of the first track of the
// different clusters.
std::unordered_map<int, int> hitToTrack;
std::unordered_map<std::size_t, std::size_t> hitToTrack;

// Loop over all the tracks
for (auto track = trackMap.rbegin(); track != trackMap.rend(); ++track) {
std::vector<int> hits = track->second.second;
std::vector<std::size_t> hits = track->second.second;
auto matchedTrack = hitToTrack.end();
// Loop over all the hits in the track
for (auto hit = hits.begin(); hit != hits.end(); hit++) {
Expand All @@ -36,7 +38,7 @@ std::unordered_map<int, std::vector<int>> Acts::detail::clusterDuplicateTracks(
// None of the hits have been matched to a track create a new cluster
if (matchedTrack == hitToTrack.end()) {
cluster.emplace(track->second.first,
std::vector<int>(1, track->second.first));
std::vector<std::size_t>(1, track->second.first));
for (const auto& hit : hits) {
// Add the hits of the new cluster to the hitToTrack
hitToTrack.emplace(hit, track->second.first);
Expand Down
1 change: 1 addition & 0 deletions Examples/Algorithms/TrackFindingML/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(SOURCES
if(ACTS_BUILD_PLUGIN_MLPACK)
list(APPEND SOURCES
src/AmbiguityResolutionMLDBScanAlgorithm.cpp
src/SeedFilterMLAlgorithm.cpp
)
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ class AmbiguityResolutionML : public IAlgorithm {
/// @param tracks is the input track container
/// @param nMeasurementsMin minimum number of measurement per track
/// @return an ordered list containing pairs of track ID and associated measurement ID
std::multimap<int, std::pair<int, std::vector<int>>> mapTrackHits(
const ConstTrackContainer& tracks, int nMeasurementsMin) const;
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
mapTrackHits(const ConstTrackContainer& tracks, int nMeasurementsMin) const;

/// Prepare the output track container to be written
///
/// @param tracks is the input track container
/// @param goodTracks is list of the IDs of all the tracks we want to keep
ConstTrackContainer prepareOutputTrack(const ConstTrackContainer& tracks,
std::vector<int>& goodTracks) const;
ConstTrackContainer prepareOutputTrack(
const ConstTrackContainer& tracks,
std::vector<std::size_t>& goodTracks) const;
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/Onnx/SeedClassifier.hpp"
#include "ActsExamples/EventData/SimSeed.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/TrackFindingML/AmbiguityResolutionML.hpp"

#include <string>

namespace ActsExamples {

/// Removes seeds that seem to be duplicates or fakes.
///
/// The implementation works as follows:
/// 1) Cluster together nearby seeds using a DBScan
/// 2) For each seed use a neural network to compute a score
/// 3) In each cluster keep the seed with the highest score
///
/// The seed collection and the track parameters collection are read and
/// written together.
class SeedFilterMLAlgorithm : public IAlgorithm {
public:
struct Config {
/// Input estimated track parameters collection.
std::string inputTrackParameters;
/// Input seeds collection.
std::string inputSimSeeds;
/// Path to the ONNX model of the seed filter neural network.
std::string inputSeedFilterNN;
/// Output estimated track parameters collection.
std::string outputTrackParameters;
/// Output seeds collection.
std::string outputSimSeeds;
/// Maximum distance between 2 seeds to be clustered in the DBScan
float epsilonDBScan = 0.03;
/// Minimum number of seeds to create a cluster in the DBScan
int minPointsDBScan = 2;
/// Minimum score a seed needs to be selected
float minSeedScore = 0.1;
/// Clustering weight for phi: phi is divided by this value before the DBSCAN
double clusteringWeighPhi = 1.0;
/// Clustering weight for eta: eta is divided by this value before the DBSCAN
double clusteringWeighEta = 1.0;
/// Clustering weight for z: the seed z is divided by this value before the
/// DBSCAN
double clusteringWeighZ = 50.0;
/// Clustering weight for pT: pT is divided by this value before the DBSCAN
double clusteringWeighPt = 1.0;
};

/// Construct the seed filter algorithm.
///
/// @param cfg is the algorithm configuration
/// @param lvl is the logging level
SeedFilterMLAlgorithm(Config cfg, Acts::Logging::Level lvl);

/// Run the seed filter algorithm.
///
/// @param ctx is the algorithm context with event information
/// @return a process code indicating success or failure
ProcessCode execute(const AlgorithmContext& ctx) const final;

/// Const access to the config
const Config& config() const { return m_cfg; }

private:
// Algorithm configuration; declared before m_seedClassifier, which is
// built from the model path it holds.
Config m_cfg;
// ONNX model for seed selection
Acts::SeedClassifier m_seedClassifier;
// Data handles for the input collections (seeds and their track parameters).
ReadDataHandle<TrackParametersContainer> m_inputTrackParameters{
this, "InputTrackParameters"};
ReadDataHandle<SimSeedContainer> m_inputSimSeeds{this, "InputSimSeeds"};
// Data handles for the filtered output collections.
WriteDataHandle<TrackParametersContainer> m_outputTrackParameters{
this, "OutputTrackParameters"};
WriteDataHandle<SimSeedContainer> m_outputSimSeeds{this, "OutputSimSeeds"};
};

} // namespace ActsExamples
15 changes: 8 additions & 7 deletions Examples/Algorithms/TrackFindingML/src/AmbiguityResolutionML.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,24 @@ ActsExamples::AmbiguityResolutionML::AmbiguityResolutionML(
std::string name, Acts::Logging::Level lvl)
: ActsExamples::IAlgorithm(name, lvl) {}

std::multimap<int, std::pair<int, std::vector<int>>>
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
ActsExamples::AmbiguityResolutionML::mapTrackHits(
const ActsExamples::ConstTrackContainer& tracks,
int nMeasurementsMin) const {
std::multimap<int, std::pair<int, std::vector<int>>> trackMap;
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>> trackMap;
// Loop over all the trajectories in the events
for (const auto& track : tracks) {
std::vector<int> hits;
std::vector<std::size_t> hits;
int nbMeasurements = 0;
// Store the hits id for the trajectory and compute the number of
// measurement
tracks.trackStateContainer().visitBackwards(
track.tipIndex(), [&](const auto& state) {
if (state.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
int indexHit = state.getUncalibratedSourceLink()
.template get<ActsExamples::IndexSourceLink>()
.index();
std::size_t indexHit =
state.getUncalibratedSourceLink()
.template get<ActsExamples::IndexSourceLink>()
.index();
hits.emplace_back(indexHit);
++nbMeasurements;
}
Expand All @@ -47,7 +48,7 @@ ActsExamples::AmbiguityResolutionML::mapTrackHits(
ActsExamples::ConstTrackContainer
ActsExamples::AmbiguityResolutionML::prepareOutputTrack(
const ActsExamples::ConstTrackContainer& tracks,
std::vector<int>& goodTracks) const {
std::vector<std::size_t>& goodTracks) const {
std::shared_ptr<Acts::ConstVectorMultiTrajectory> trackStateContainer =
tracks.trackStateContainerHolder();
auto trackContainer = std::make_shared<Acts::VectorTrackContainer>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ ActsExamples::ProcessCode ActsExamples::AmbiguityResolutionMLAlgorithm::execute(
// Read input data
const auto& tracks = m_inputTracks(ctx);
// Associate measurement to their respective tracks
std::multimap<int, std::pair<int, std::vector<int>>> trackMap =
mapTrackHits(tracks, m_cfg.nMeasurementsMin);
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
trackMap = mapTrackHits(tracks, m_cfg.nMeasurementsMin);
auto cluster = Acts::detail::clusterDuplicateTracks(trackMap);
// Select the ID of the track we want to keep
std::vector<int> goodTracks =
m_duplicateClassifier.solveAmbuguity(cluster, tracks);
std::vector<std::size_t> goodTracks =
m_duplicateClassifier.solveAmbiguity(cluster, tracks);
// Prepare the output track collection from the IDs
auto outputTracks = prepareOutputTrack(tracks, goodTracks);
m_outputTracks(ctx, std::move(outputTracks));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ ActsExamples::AmbiguityResolutionMLDBScanAlgorithm::execute(
// Read input data
const auto& tracks = m_inputTracks(ctx);
// Associate measurement to their respective tracks
std::multimap<int, std::pair<int, std::vector<int>>> trackMap =
mapTrackHits(tracks, m_cfg.nMeasurementsMin);
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
trackMap = mapTrackHits(tracks, m_cfg.nMeasurementsMin);
// Cluster the tracks using DBscan
auto cluster = Acts::dbscanTrackClustering(
trackMap, tracks, m_cfg.epsilonDBScan, m_cfg.minPointsDBScan);
// Select the ID of the track we want to keep
std::vector<int> goodTracks =
m_duplicateClassifier.solveAmbuguity(cluster, tracks);
std::vector<std::size_t> goodTracks =
m_duplicateClassifier.solveAmbiguity(cluster, tracks);
// Prepare the output track collection from the IDs
auto outputTracks = prepareOutputTrack(tracks, goodTracks);
m_outputTracks(ctx, std::move(outputTracks));
Expand Down
102 changes: 102 additions & 0 deletions Examples/Algorithms/TrackFindingML/src/SeedFilterMLAlgorithm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "ActsExamples/TrackFindingML/SeedFilterMLAlgorithm.hpp"

#include "Acts/Plugins/Mlpack/SeedFilterDBScanClustering.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Framework/WhiteBoard.hpp"

#include <cmath>
#include <iterator>
#include <map>
#include <stdexcept>
#include <utility>
#include <vector>

/// Build the algorithm: store the configuration, load the classifier from the
/// configured model path, validate the collection names and bind the data
/// handles.
///
/// @param cfg is the algorithm configuration
/// @param lvl is the logging level
/// @throws std::invalid_argument if one of the four collection names is empty
ActsExamples::SeedFilterMLAlgorithm::SeedFilterMLAlgorithm(
    ActsExamples::SeedFilterMLAlgorithm::Config cfg, Acts::Logging::Level lvl)
    : ActsExamples::IAlgorithm("SeedFilterMLAlgorithm", lvl),
      m_cfg(std::move(cfg)),
      m_seedClassifier(m_cfg.inputSeedFilterNN.c_str()) {
  // Throws with the given message when a mandatory collection name is empty.
  const auto requireCollection = [](const std::string& collection,
                                    const char* message) {
    if (collection.empty()) {
      throw std::invalid_argument(message);
    }
  };

  // The checks run in a fixed order so the first missing name is reported.
  requireCollection(m_cfg.inputTrackParameters,
                    "Missing track parameters input collection");
  requireCollection(m_cfg.inputSimSeeds, "Missing seed input collection");
  requireCollection(m_cfg.outputTrackParameters,
                    "Missing track parameters output collection");
  requireCollection(m_cfg.outputSimSeeds, "Missing seed output collection");

  // Bind every data handle to its configured collection name.
  m_inputTrackParameters.initialize(m_cfg.inputTrackParameters);
  m_inputSimSeeds.initialize(m_cfg.inputSimSeeds);
  m_outputTrackParameters.initialize(m_cfg.outputTrackParameters);
  m_outputSimSeeds.initialize(m_cfg.outputSimSeeds);
}

/// Filter the seeds of one event.
///
/// Reads the seeds and their estimated track parameters, clusters the seeds
/// with a DBScan on the weighted (phi, eta, z, pT) values, passes the clusters
/// and the per-seed network input to the seed classifier, and writes the
/// selected seeds and their track parameters to the output collections.
///
/// @param ctx is the algorithm context with event information
/// @return ProcessCode::SUCCESS once both output collections are written
/// @throws std::invalid_argument if the seed and track parameter collections
///         do not have the same size
ActsExamples::ProcessCode ActsExamples::SeedFilterMLAlgorithm::execute(
    const AlgorithmContext& ctx) const {
  // Read input data
  const auto& seeds = m_inputSimSeeds(ctx);
  const auto& params = m_inputTrackParameters(ctx);
  if (seeds.size() != params.size()) {
    throw std::invalid_argument(
        "The number of seeds and track parameters is different");
  }

  // One row of 14 features per seed for the neural network
  Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
      networkInput(seeds.size(), 14);
  // One (phi, eta, z, pT) entry per seed for the clustering; the final size is
  // known, so allocate once.
  std::vector<std::vector<double>> clusteringParams;
  clusteringParams.reserve(seeds.size());

  // Loop over the seed and parameters to fill the input for the clustering
  // and the NN
  for (std::size_t i = 0; i < seeds.size(); i++) {
    const auto& seed = seeds[i];
    const auto& boundParams = params[i].parameters();
    const auto& spacePoints = seed.sp();

    // Compute the track parameters
    double pT = std::abs(1.0 / boundParams[Acts::eBoundQOverP]) *
                std::sin(boundParams[Acts::eBoundTheta]);
    double eta = std::atanh(std::cos(boundParams[Acts::eBoundTheta]));
    double phi = boundParams[Acts::eBoundPhi];

    // Fill and weight the clustering inputs
    clusteringParams.push_back(
        {phi / m_cfg.clusteringWeighPhi, eta / m_cfg.clusteringWeighEta,
         seed.z() / m_cfg.clusteringWeighZ, pT / m_cfg.clusteringWeighPt});

    // Fill the NN input
    networkInput.row(i) << pT, eta, phi, spacePoints[0]->x(),
        spacePoints[0]->y(), spacePoints[0]->z(), spacePoints[1]->x(),
        spacePoints[1]->y(), spacePoints[1]->z(), spacePoints[2]->x(),
        spacePoints[2]->y(), spacePoints[2]->z(), seed.z(),
        seed.seedQuality();
  }

  // Cluster the seeds using DBscan
  auto cluster = Acts::dbscanSeedClustering(
      clusteringParams, m_cfg.epsilonDBScan, m_cfg.minPointsDBScan);

  // Select the ID of the seeds we want to keep
  std::vector<std::size_t> goodSeed = m_seedClassifier.solveAmbiguity(
      cluster, networkInput, m_cfg.minSeedScore);

  // Create the output seed collection
  SimSeedContainer outputSeeds;
  outputSeeds.reserve(goodSeed.size());

  // Create the output track parameters collection
  TrackParametersContainer outputTrackParameters;
  outputTrackParameters.reserve(goodSeed.size());

  for (auto i : goodSeed) {
    outputSeeds.push_back(seeds[i]);
    outputTrackParameters.push_back(params[i]);
  }

  // Hand the containers over by move: they are not used afterwards, so the
  // deep copies are unnecessary.
  m_outputSimSeeds(ctx, std::move(outputSeeds));
  m_outputTrackParameters(ctx, std::move(outputTrackParameters));

  return ActsExamples::ProcessCode::SUCCESS;
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class CsvSeedWriter : public WriterT<TrackParametersContainer> {
/// @brief Struct for brief seed summary info
///
struct SeedInfo {
std::size_t seedId = 0;
std::size_t seedID = 0;
ActsFatras::Barcode particleId;
float seedPt = -1;
float seedPhi = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CsvTrackWriter : public WriterT<ConstTrackContainer> {
///
struct TrackInfo : public Acts::MultiTrajectoryHelpers::TrajectoryState {
std::size_t trackId = 0;
unsigned int seedId = 0;
unsigned int seedID = 0;
ActsFatras::Barcode particleId;
std::size_t nMajorityHits = 0;
std::string trackType;
Expand Down
6 changes: 3 additions & 3 deletions Examples/Io/Csv/src/CsvSeedWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ ActsExamples::ProcessCode ActsExamples::CsvSeedWriter::writeT(

// track info
SeedInfo toAdd;
toAdd.seedId = iparams;
toAdd.seedID = iparams;
toAdd.particleId = majorityParticleId;
toAdd.seedPt = std::abs(1.0 / params[Acts::eBoundQOverP]) *
std::sin(params[Acts::eBoundTheta]);
Expand All @@ -160,7 +160,7 @@ ActsExamples::ProcessCode ActsExamples::CsvSeedWriter::writeT(
toAdd.seedType = truthMatched ? "duplicate" : "fake";
toAdd.measurementsID = ptrack;

infoMap[toAdd.seedId] = toAdd;
infoMap[toAdd.seedID] = toAdd;
}

mos << "seed_id,particleId,"
Expand All @@ -177,7 +177,7 @@ ActsExamples::ProcessCode ActsExamples::CsvSeedWriter::writeT(
info.seedType = "good";
}
// write the track info
mos << info.seedId << ",";
mos << info.seedID << ",";
mos << info.particleId << ",";
mos << info.seedPt << ",";
mos << info.seedEta << ",";
Expand Down
Loading
Loading