Skip to content

Commit

Permalink
feat: Graph export for Exa.TrkX (#2730)
Browse files Browse the repository at this point in the history
Allows to export the Graph after the GNN to a csv file. Needed for the GNN+CKF workflow
  • Loading branch information
benjaminhuth authored Nov 28, 2023
1 parent 7e9c2f0 commit 8f4162f
Show file tree
Hide file tree
Showing 15 changed files with 241 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "Acts/Definitions/Units.hpp"
#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
Expand Down Expand Up @@ -52,6 +53,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
/// Output protoTracks collection.
std::string outputProtoTracks;

/// Output graph (optional)
std::string outputGraph;

std::shared_ptr<Acts::GraphConstructionBase> graphConstructor;

std::vector<std::shared_ptr<Acts::EdgeClassificationBase>> edgeClassifiers;
Expand All @@ -67,6 +71,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
float clusterXScale = 1.f;
float clusterYScale = 1.f;

/// Remove track candidates with 2 or less hits
bool filterShortTracks = false;

/// Target graph properties
std::size_t targetMinHits = 3;
double targetMinPT = 500 * Acts::UnitConstants::MeV;
Expand Down Expand Up @@ -114,6 +121,8 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {

WriteDataHandle<ProtoTrackContainer> m_outputProtoTracks{this,
"OutputProtoTracks"};
WriteDataHandle<Acts::TorchGraphStoreHook::Graph> m_outputGraph{
this, "OutputGraph"};

// for truth graph
ReadDataHandle<SimHitContainer> m_inputSimHits{this, "InputSimHits"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"

#include "Acts/Definitions/Units.hpp"
#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
#include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
#include "Acts/Utilities/Zip.hpp"
#include "ActsExamples/EventData/Index.hpp"
Expand All @@ -31,6 +32,7 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook {
std::unique_ptr<const Acts::Logger> m_logger;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;
std::unique_ptr<Acts::TorchGraphStoreHook> m_graphStoreHook;

const Acts::Logger& logger() const { return *m_logger; }

Expand Down Expand Up @@ -98,17 +100,22 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook {
truthGraph, logger.clone());
m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
targetGraph, logger.clone());
m_graphStoreHook = std::make_unique<Acts::TorchGraphStoreHook>();
}

~ExamplesEdmHook() {}

void operator()(const std::any& nodes, const std::any& edges) const override {
auto storedGraph() const { return m_graphStoreHook->storedGraph(); }

void operator()(const std::any& nodes, const std::any& edges,
const std::any& weights) const override {
ACTS_INFO("Metrics for total graph:");
(*m_truthGraphHook)(nodes, edges);
(*m_truthGraphHook)(nodes, edges, weights);
ACTS_INFO("Metrics for target graph (pT > "
<< m_targetPT / Acts::UnitConstants::GeV
<< " GeV, nHits >= " << m_targetSize << "):");
(*m_targetGraphHook)(nodes, edges);
(*m_targetGraphHook)(nodes, edges, weights);
(*m_graphStoreHook)(nodes, edges, weights);
}
};

Expand Down Expand Up @@ -153,6 +160,8 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(
m_inputParticles.maybeInitialize(m_cfg.inputParticles);
m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap);

m_outputGraph.maybeInitialize(m_cfg.outputGraph);

// reserve space for timing
m_timing.classifierTimes.resize(
m_cfg.edgeClassifiers.size(),
Expand Down Expand Up @@ -267,15 +276,35 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
// Make the prototracks
std::vector<ProtoTrack> protoTracks;
protoTracks.reserve(trackCandidates.size());

int nShortTracks = 0;

for (auto& x : trackCandidates) {
if (m_cfg.filterShortTracks && x.size() < 3) {
nShortTracks++;
continue;
}

ProtoTrack onetrack;
onetrack.reserve(x.size());

std::copy(x.begin(), x.end(), std::back_inserter(onetrack));
protoTracks.push_back(std::move(onetrack));
}

ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits");
ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
m_outputProtoTracks(ctx, std::move(protoTracks));

if (auto dhook = dynamic_cast<ExamplesEdmHook*>(&*hook);
dhook && m_outputGraph.isInitialized()) {
auto graph = dhook->storedGraph();
std::transform(
graph.first.begin(), graph.first.end(), graph.first.begin(),
[&](const auto& a) -> int64_t { return spacepointIDs.at(a); });
m_outputGraph(ctx, std::move(graph));
}

return ActsExamples::ProcessCode::SUCCESS;
}

Expand Down
1 change: 1 addition & 0 deletions Examples/Io/Csv/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(
src/CsvTrackWriter.cpp
src/CsvProtoTrackWriter.cpp
src/CsvSpacePointWriter.cpp
src/CsvExaTrkXGraphWriter.cpp
src/CsvBFieldWriter.cpp)
target_include_directories(
ActsExamplesIoCsv
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Utilities/Logger.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Framework/WriterT.hpp"
#include "ActsExamples/Utilities/Paths.hpp"

#include <cstddef>
#include <limits>
#include <string>

namespace ActsExamples {
struct AlgorithmContext;

class CsvExaTrkXGraphWriter final
: public WriterT<std::pair<std::vector<int64_t>, std::vector<float>>> {
public:
struct Config {
/// Which simulated (truth) hits collection to use.
std::string inputGraph;
/// Where to place output files
std::string outputDir;
/// Output filename stem.
std::string outputStem = "exatrkx-graph";
};

/// Construct the cluster writer.
///
/// @param config is the configuration object
/// @param level is the logging level
CsvExaTrkXGraphWriter(const Config& config, Acts::Logging::Level level);

/// Readonly access to the config
const Config& config() const { return m_cfg; }

protected:
/// Type-specific write implementation.
///
/// @param[in] ctx is the algorithm context
/// @param[in] simHits are the simhits to be written
ProcessCode writeT(const AlgorithmContext& ctx,
const std::pair<std::vector<int64_t>, std::vector<float>>&
graph) override;

private:
Config m_cfg;
};

} // namespace ActsExamples
56 changes: 56 additions & 0 deletions Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp"

#include "Acts/Definitions/Algebra.hpp"
#include "Acts/Definitions/Common.hpp"
#include "Acts/Definitions/Units.hpp"
#include "ActsExamples/Framework/AlgorithmContext.hpp"
#include "ActsExamples/Utilities/Paths.hpp"
#include "ActsFatras/EventData/Barcode.hpp"

#include <stdexcept>
#include <vector>

#include <dfe/dfe_io_dsv.hpp>
#include <dfe/dfe_namedtuple.hpp>

struct GraphData {
int64_t edge0;
int64_t edge1;
float weight;
DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight);
};

ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter(
const ActsExamples::CsvExaTrkXGraphWriter::Config& config,
Acts::Logging::Level level)
: WriterT(config.inputGraph, "CsvExaTrkXGraphWriter", level),
m_cfg(config) {}

ActsExamples::ProcessCode ActsExamples::CsvExaTrkXGraphWriter::writeT(
const ActsExamples::AlgorithmContext& ctx,
const std::pair<std::vector<int64_t>, std::vector<float>>& graph) {
std::string path = perEventFilepath(
m_cfg.outputDir, m_cfg.outputStem + ".csv", ctx.eventNumber);

dfe::NamedTupleCsvWriter<GraphData> writer(path);

const auto& [edges, weights] = graph;

for (auto i = 0ul; i < weights.size(); ++i) {
GraphData edge{};
edge.edge0 = edges[2 * i];
edge.edge1 = edges[2 * i + 1];
edge.weight = weights[i];
writer.append(edge);
}

return ActsExamples::ProcessCode::SUCCESS;
}
7 changes: 4 additions & 3 deletions Examples/Python/src/ExaTrkXTrackFinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,10 @@ void addExaTrkXTrackFinding(Context &ctx) {
ActsExamples::TrackFindingAlgorithmExaTrkX, mex,
"TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits,
inputParticles, inputClusters, inputMeasurementSimhitsMap,
outputProtoTracks, graphConstructor, edgeClassifiers, trackBuilder,
rScale, phiScale, zScale, cellCountScale, cellSumScale, clusterXScale,
clusterYScale, targetMinHits, targetMinPT);
outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers,
trackBuilder, rScale, phiScale, zScale, cellCountScale, cellSumScale,
clusterXScale, clusterYScale, filterShortTracks, targetMinHits,
targetMinPT);

{
auto cls =
Expand Down
5 changes: 5 additions & 0 deletions Examples/Python/src/Output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ActsExamples/Digitization/DigitizationConfig.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Io/Csv/CsvBFieldWriter.hpp"
#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp"
#include "ActsExamples/Io/Csv/CsvMeasurementWriter.hpp"
#include "ActsExamples/Io/Csv/CsvParticleWriter.hpp"
#include "ActsExamples/Io/Csv/CsvPlanarClusterWriter.hpp"
Expand Down Expand Up @@ -412,5 +413,9 @@ void addOutput(Context& ctx) {
register_csv_bfield_writer_binding<Writer::CoordinateType::RZ, true>(w);
register_csv_bfield_writer_binding<Writer::CoordinateType::RZ, false>(w);
}

ACTS_PYTHON_DECLARE_WRITER(ActsExamples::CsvExaTrkXGraphWriter, mex,
"CsvExaTrkXGraphWriter", inputGraph, outputDir,
outputStem);
}
} // namespace Acts::Python
1 change: 1 addition & 0 deletions Plugins/ExaTrkX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ if(ACTS_EXATRKX_ENABLE_TORCH)
src/TorchMetricLearning.cpp
src/BoostTrackBuilding.cpp
src/TorchTruthGraphMetricsHook.cpp
src/TorchGraphStoreHook.cpp
)
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ struct ExaTrkXTiming {
class ExaTrkXHook {
public:
virtual ~ExaTrkXHook() {}
virtual void operator()(const std::any &, const std::any &) const {};
virtual void operator()(const std::any &nodes, const std::any &edges,
const std::any &weights) const {};
};

class ExaTrkXPipeline {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
#include "Acts/Plugins/ExaTrkX/detail/CantorEdge.hpp"
#include "Acts/Utilities/Logger.hpp"

namespace Acts {

class TorchGraphStoreHook : public ExaTrkXHook {
public:
using Graph = std::pair<std::vector<int64_t>, std::vector<float>>;

private:
std::unique_ptr<Graph> m_storedGraph;

public:
TorchGraphStoreHook();
~TorchGraphStoreHook() override {}

void operator()(const std::any &, const std::any &edges,
const std::any &weights) const override;

const Graph &storedGraph() const { return *m_storedGraph; }
};

} // namespace Acts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class TorchTruthGraphMetricsHook : public ExaTrkXHook {
std::unique_ptr<const Acts::Logger> l);
~TorchTruthGraphMetricsHook() override {}

void operator()(const std::any &, const std::any &edges) const override;
void operator()(const std::any &, const std::any &edges,
const std::any &) const override;
};

} // namespace Acts
4 changes: 2 additions & 2 deletions Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
timing->graphBuildingTime = t1 - t0;
}

hook(nodes, edges);
hook(nodes, edges, {});

std::any edge_weights;
timing->classifierTimes.clear();
Expand All @@ -63,7 +63,7 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
edges = std::move(newEdges);
edge_weights = std::move(newWeights);

hook(nodes, edges);
hook(nodes, edges, edge_weights);
}

t0 = std::chrono::high_resolution_clock::now();
Expand Down
33 changes: 33 additions & 0 deletions Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"

#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"

#include <torch/torch.h>

Acts::TorchGraphStoreHook::TorchGraphStoreHook() {
m_storedGraph = std::make_unique<Graph>();
}

void Acts::TorchGraphStoreHook::operator()(const std::any&,
const std::any& edges,
const std::any& weights) const {
if (not weights.has_value()) {
return;
}

m_storedGraph->first = detail::tensor2DToVector<int64_t>(
std::any_cast<torch::Tensor>(edges).t());

auto cpuWeights = std::any_cast<torch::Tensor>(weights).to(torch::kCPU);
m_storedGraph->second =
std::vector<float>(cpuWeights.data_ptr<float>(),
cpuWeights.data_ptr<float>() + cpuWeights.numel());
}
Loading

0 comments on commit 8f4162f

Please sign in to comment.