Skip to content

Commit

Permalink
feat: ML based seed filtering (#2709)
Browse files Browse the repository at this point in the history
This pull request add a ML based seed selection that can be used to reduce the number of seed after the seeding step.
This PR is pending on the merging of #2690.
  • Loading branch information
Corentin-Allaire authored Dec 14, 2023
1 parent 66c671c commit 61bc4e1
Show file tree
Hide file tree
Showing 32 changed files with 1,532 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ namespace detail {
///
/// @param trackMap : Multimap storing pair of track ID and vector of measurement ID. The keys are the number of measurement and are just there to facilitate the ordering.
/// @return an unordered map representing the clusters, the keys the ID of the primary track of each cluster and the store a vector of track IDs.
std::unordered_map<int, std::vector<int>> clusterDuplicateTracks(
const std::multimap<int, std::pair<int, std::vector<int>>>& trackMap);
std::unordered_map<std::size_t, std::vector<std::size_t>>
clusterDuplicateTracks(
const std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>&
trackMap);

} // namespace detail
} // namespace Acts
14 changes: 8 additions & 6 deletions Core/src/TrackFinding/AmbiguityTrackClustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@

#include <iterator>

std::unordered_map<int, std::vector<int>> Acts::detail::clusterDuplicateTracks(
const std::multimap<int, std::pair<int, std::vector<int>>>& trackMap) {
std::unordered_map<std::size_t, std::vector<std::size_t>>
Acts::detail::clusterDuplicateTracks(
const std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>&
trackMap) {
// Unordered map associating a vector with all the track ID of a cluster to
// the ID of the first track of the cluster
std::unordered_map<int, std::vector<int>> cluster;
std::unordered_map<std::size_t, std::vector<std::size_t>> cluster;
// Unordered map associating hits to the ID of the first track of the
// different clusters.
std::unordered_map<int, int> hitToTrack;
std::unordered_map<std::size_t, std::size_t> hitToTrack;

// Loop over all the tracks
for (auto track = trackMap.rbegin(); track != trackMap.rend(); ++track) {
std::vector<int> hits = track->second.second;
std::vector<std::size_t> hits = track->second.second;
auto matchedTrack = hitToTrack.end();
// Loop over all the hits in the track
for (auto hit = hits.begin(); hit != hits.end(); hit++) {
Expand All @@ -36,7 +38,7 @@ std::unordered_map<int, std::vector<int>> Acts::detail::clusterDuplicateTracks(
// None of the hits have been matched to a track create a new cluster
if (matchedTrack == hitToTrack.end()) {
cluster.emplace(track->second.first,
std::vector<int>(1, track->second.first));
std::vector<std::size_t>(1, track->second.first));
for (const auto& hit : hits) {
// Add the hits of the new cluster to the hitToTrack
hitToTrack.emplace(hit, track->second.first);
Expand Down
1 change: 1 addition & 0 deletions Examples/Algorithms/TrackFindingML/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(SOURCES
if(ACTS_BUILD_PLUGIN_MLPACK)
list(APPEND SOURCES
src/AmbiguityResolutionMLDBScanAlgorithm.cpp
src/SeedFilterMLAlgorithm.cpp
)
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ class AmbiguityResolutionML : public IAlgorithm {
/// @param tracks is the input track container
/// @param nMeasurementsMin minimum number of measurement per track
/// @return an ordered list containing pairs of track ID and associated measurement ID
std::multimap<int, std::pair<int, std::vector<int>>> mapTrackHits(
const ConstTrackContainer& tracks, int nMeasurementsMin) const;
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
mapTrackHits(const ConstTrackContainer& tracks, int nMeasurementsMin) const;

/// Prepare the output track container to be written
///
/// @param tracks is the input track container
/// @param goodTracks is list of the IDs of all the tracks we want to keep
ConstTrackContainer prepareOutputTrack(const ConstTrackContainer& tracks,
std::vector<int>& goodTracks) const;
ConstTrackContainer prepareOutputTrack(
const ConstTrackContainer& tracks,
std::vector<std::size_t>& goodTracks) const;
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/Onnx/SeedClassifier.hpp"
#include "ActsExamples/EventData/SimSeed.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/TrackFindingML/AmbiguityResolutionML.hpp"

#include <string>

namespace ActsExamples {

/// Removes seeds that seem to be duplicated and fake.
///
/// The implementation works as follows:
/// 1) Cluster together nearby seeds using a DBScan
/// 2) For each seed use a neural network to compute a score
/// 3) In each cluster keep the seed with the highest score
class SeedFilterMLAlgorithm : public IAlgorithm {
public:
struct Config {
/// Input estimated track parameters collection.
std::string inputTrackParameters;
/// Input seeds collection.
std::string inputSimSeeds;
/// Path to the ONNX model for the duplicate neural network
std::string inputSeedFilterNN;
/// Output estimated track parameters collection.
std::string outputTrackParameters;
/// Output seeds collection.
std::string outputSimSeeds;
/// Maximum distance between 2 tracks to be clustered in the DBScan
float epsilonDBScan = 0.03;
/// Minimum number of tracks to create a cluster in the DBScan
int minPointsDBScan = 2;
/// Minimum score a seed need to be selected
float minSeedScore = 0.1;
/// Clustering parameters weight for phi used before the DBSCAN
double clusteringWeighPhi = 1.0;
/// Clustering parameters weight for eta used before the DBSCAN
double clusteringWeighEta = 1.0;
/// Clustering parameters weight for z used before the DBSCAN
double clusteringWeighZ = 50.0;
/// Clustering parameters weight for pT used before the DBSCAN
double clusteringWeighPt = 1.0;
};

/// Construct the seed filter algorithm.
///
/// @param cfg is the algorithm configuration
/// @param lvl is the logging level
SeedFilterMLAlgorithm(Config cfg, Acts::Logging::Level lvl);

/// Run the seed filter algorithm.
///
/// @param cxt is the algorithm context with event information
/// @return a process code indication success or failure
ProcessCode execute(const AlgorithmContext& ctx) const final;

/// Const access to the config
const Config& config() const { return m_cfg; }

private:
Config m_cfg;
// ONNX model for track selection
Acts::SeedClassifier m_seedClassifier;
ReadDataHandle<TrackParametersContainer> m_inputTrackParameters{
this, "InputTrackParameters"};
ReadDataHandle<SimSeedContainer> m_inputSimSeeds{this, "InputSimSeeds"};
WriteDataHandle<TrackParametersContainer> m_outputTrackParameters{
this, "OutputTrackParameters"};
WriteDataHandle<SimSeedContainer> m_outputSimSeeds{this, "OutputSimSeeds"};
};

} // namespace ActsExamples
15 changes: 8 additions & 7 deletions Examples/Algorithms/TrackFindingML/src/AmbiguityResolutionML.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,24 @@ ActsExamples::AmbiguityResolutionML::AmbiguityResolutionML(
std::string name, Acts::Logging::Level lvl)
: ActsExamples::IAlgorithm(name, lvl) {}

std::multimap<int, std::pair<int, std::vector<int>>>
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
ActsExamples::AmbiguityResolutionML::mapTrackHits(
const ActsExamples::ConstTrackContainer& tracks,
int nMeasurementsMin) const {
std::multimap<int, std::pair<int, std::vector<int>>> trackMap;
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>> trackMap;
// Loop over all the trajectories in the events
for (const auto& track : tracks) {
std::vector<int> hits;
std::vector<std::size_t> hits;
int nbMeasurements = 0;
// Store the hits id for the trajectory and compute the number of
// measurement
tracks.trackStateContainer().visitBackwards(
track.tipIndex(), [&](const auto& state) {
if (state.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
int indexHit = state.getUncalibratedSourceLink()
.template get<ActsExamples::IndexSourceLink>()
.index();
std::size_t indexHit =
state.getUncalibratedSourceLink()
.template get<ActsExamples::IndexSourceLink>()
.index();
hits.emplace_back(indexHit);
++nbMeasurements;
}
Expand All @@ -47,7 +48,7 @@ ActsExamples::AmbiguityResolutionML::mapTrackHits(
ActsExamples::ConstTrackContainer
ActsExamples::AmbiguityResolutionML::prepareOutputTrack(
const ActsExamples::ConstTrackContainer& tracks,
std::vector<int>& goodTracks) const {
std::vector<std::size_t>& goodTracks) const {
std::shared_ptr<Acts::ConstVectorMultiTrajectory> trackStateContainer =
tracks.trackStateContainerHolder();
auto trackContainer = std::make_shared<Acts::VectorTrackContainer>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ ActsExamples::ProcessCode ActsExamples::AmbiguityResolutionMLAlgorithm::execute(
// Read input data
const auto& tracks = m_inputTracks(ctx);
// Associate measurement to their respective tracks
std::multimap<int, std::pair<int, std::vector<int>>> trackMap =
mapTrackHits(tracks, m_cfg.nMeasurementsMin);
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
trackMap = mapTrackHits(tracks, m_cfg.nMeasurementsMin);
auto cluster = Acts::detail::clusterDuplicateTracks(trackMap);
// Select the ID of the track we want to keep
std::vector<int> goodTracks =
m_duplicateClassifier.solveAmbuguity(cluster, tracks);
std::vector<std::size_t> goodTracks =
m_duplicateClassifier.solveAmbiguity(cluster, tracks);
// Prepare the output track collection from the IDs
auto outputTracks = prepareOutputTrack(tracks, goodTracks);
m_outputTracks(ctx, std::move(outputTracks));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ ActsExamples::AmbiguityResolutionMLDBScanAlgorithm::execute(
// Read input data
const auto& tracks = m_inputTracks(ctx);
// Associate measurement to their respective tracks
std::multimap<int, std::pair<int, std::vector<int>>> trackMap =
mapTrackHits(tracks, m_cfg.nMeasurementsMin);
std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
trackMap = mapTrackHits(tracks, m_cfg.nMeasurementsMin);
// Cluster the tracks using DBscan
auto cluster = Acts::dbscanTrackClustering(
trackMap, tracks, m_cfg.epsilonDBScan, m_cfg.minPointsDBScan);
// Select the ID of the track we want to keep
std::vector<int> goodTracks =
m_duplicateClassifier.solveAmbuguity(cluster, tracks);
std::vector<std::size_t> goodTracks =
m_duplicateClassifier.solveAmbiguity(cluster, tracks);
// Prepare the output track collection from the IDs
auto outputTracks = prepareOutputTrack(tracks, goodTracks);
m_outputTracks(ctx, std::move(outputTracks));
Expand Down
102 changes: 102 additions & 0 deletions Examples/Algorithms/TrackFindingML/src/SeedFilterMLAlgorithm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "ActsExamples/TrackFindingML/SeedFilterMLAlgorithm.hpp"

#include "Acts/Plugins/Mlpack/SeedFilterDBScanClustering.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Framework/WhiteBoard.hpp"

#include <iterator>
#include <map>

ActsExamples::SeedFilterMLAlgorithm::SeedFilterMLAlgorithm(
ActsExamples::SeedFilterMLAlgorithm::Config cfg, Acts::Logging::Level lvl)
: ActsExamples::IAlgorithm("SeedFilterMLAlgorithm", lvl),
m_cfg(std::move(cfg)),
m_seedClassifier(m_cfg.inputSeedFilterNN.c_str()) {
if (m_cfg.inputTrackParameters.empty()) {
throw std::invalid_argument("Missing track parameters input collection");
}
if (m_cfg.inputSimSeeds.empty()) {
throw std::invalid_argument("Missing seed input collection");
}
if (m_cfg.outputTrackParameters.empty()) {
throw std::invalid_argument("Missing track parameters output collection");
}
if (m_cfg.outputSimSeeds.empty()) {
throw std::invalid_argument("Missing seed output collection");
}
m_inputTrackParameters.initialize(m_cfg.inputTrackParameters);
m_inputSimSeeds.initialize(m_cfg.inputSimSeeds);
m_outputTrackParameters.initialize(m_cfg.outputTrackParameters);
m_outputSimSeeds.initialize(m_cfg.outputSimSeeds);
}

ActsExamples::ProcessCode ActsExamples::SeedFilterMLAlgorithm::execute(
const AlgorithmContext& ctx) const {
// Read input data
const auto& seeds = m_inputSimSeeds(ctx);
const auto& params = m_inputTrackParameters(ctx);
if (seeds.size() != params.size()) {
throw std::invalid_argument(
"The number of seeds and track parameters is different");
}

Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
networkInput(seeds.size(), 14);
std::vector<std::vector<double>> clusteringParams;
// Loop over the seed and parameters to fill the input for the clustering
// and the NN
for (std::size_t i = 0; i < seeds.size(); i++) {
// Compute the track parameters
double pT = std::abs(1.0 / params[i].parameters()[Acts::eBoundQOverP]) *
std::sin(params[i].parameters()[Acts::eBoundTheta]);
double eta =
std::atanh(std::cos(params[i].parameters()[Acts::eBoundTheta]));
double phi = params[i].parameters()[Acts::eBoundPhi];

// Fill and weight the clustering inputs
clusteringParams.push_back(
{phi / m_cfg.clusteringWeighPhi, eta / m_cfg.clusteringWeighEta,
seeds[i].z() / m_cfg.clusteringWeighZ, pT / m_cfg.clusteringWeighPt});

// Fill the NN input
networkInput.row(i) << pT, eta, phi, seeds[i].sp()[0]->x(),
seeds[i].sp()[0]->y(), seeds[i].sp()[0]->z(), seeds[i].sp()[1]->x(),
seeds[i].sp()[1]->y(), seeds[i].sp()[1]->z(), seeds[i].sp()[2]->x(),
seeds[i].sp()[2]->y(), seeds[i].sp()[2]->z(), seeds[i].z(),
seeds[i].seedQuality();
}

// Cluster the tracks using DBscan
auto cluster = Acts::dbscanSeedClustering(
clusteringParams, m_cfg.epsilonDBScan, m_cfg.minPointsDBScan);

// Select the ID of the track we want to keep
std::vector<std::size_t> goodSeed = m_seedClassifier.solveAmbiguity(
cluster, networkInput, m_cfg.minSeedScore);

// Create the output seed collection
SimSeedContainer outputSeeds;
outputSeeds.reserve(goodSeed.size());

// Create the output track parameters collection
TrackParametersContainer outputTrackParameters;
outputTrackParameters.reserve(goodSeed.size());

for (auto i : goodSeed) {
outputSeeds.push_back(seeds[i]);
outputTrackParameters.push_back(params[i]);
}

m_outputSimSeeds(ctx, SimSeedContainer{outputSeeds});
m_outputTrackParameters(ctx, TrackParametersContainer{outputTrackParameters});

return ActsExamples::ProcessCode::SUCCESS;
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class CsvSeedWriter : public WriterT<TrackParametersContainer> {
/// @brief Struct for brief seed summary info
///
struct SeedInfo {
std::size_t seedId = 0;
std::size_t seedID = 0;
ActsFatras::Barcode particleId;
float seedPt = -1;
float seedPhi = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CsvTrackWriter : public WriterT<ConstTrackContainer> {
///
struct TrackInfo : public Acts::MultiTrajectoryHelpers::TrajectoryState {
std::size_t trackId = 0;
unsigned int seedId = 0;
unsigned int seedID = 0;
ActsFatras::Barcode particleId;
std::size_t nMajorityHits = 0;
std::string trackType;
Expand Down
6 changes: 3 additions & 3 deletions Examples/Io/Csv/src/CsvSeedWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ ActsExamples::ProcessCode ActsExamples::CsvSeedWriter::writeT(

// track info
SeedInfo toAdd;
toAdd.seedId = iparams;
toAdd.seedID = iparams;
toAdd.particleId = majorityParticleId;
toAdd.seedPt = std::abs(1.0 / params[Acts::eBoundQOverP]) *
std::sin(params[Acts::eBoundTheta]);
Expand All @@ -160,7 +160,7 @@ ActsExamples::ProcessCode ActsExamples::CsvSeedWriter::writeT(
toAdd.seedType = truthMatched ? "duplicate" : "fake";
toAdd.measurementsID = ptrack;

infoMap[toAdd.seedId] = toAdd;
infoMap[toAdd.seedID] = toAdd;
}

mos << "seed_id,particleId,"
Expand All @@ -177,7 +177,7 @@ ActsExamples::ProcessCode ActsExamples::CsvSeedWriter::writeT(
info.seedType = "good";
}
// write the track info
mos << info.seedId << ",";
mos << info.seedID << ",";
mos << info.particleId << ",";
mos << info.seedPt << ",";
mos << info.seedEta << ",";
Expand Down
Loading

0 comments on commit 61bc4e1

Please sign in to comment.