Skip to content

Commit

Permalink
Made Acts use traccc::spacepoint directly.
Browse files Browse the repository at this point in the history
With the latest Acts version it's not necessary to create a
dummy SpacePoint class anymore. We can put a wrapper around
traccc::spacepoint directly.

At the same time simplified / cleaned the code a little.

Thanks to Carlo for his suggestions!
  • Loading branch information
krasznaa committed Nov 13, 2024
1 parent b1f8b5c commit a12397e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 147 deletions.
5 changes: 2 additions & 3 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TRACCC library, part of the ACTS project (R&D line)
#
# (c) 2021-2023 CERN for the benefit of the ACTS project
# (c) 2021-2024 CERN for the benefit of the ACTS project
#
# Mozilla Public License Version 2.0

Expand All @@ -17,8 +17,7 @@ add_library( traccc_tests_common STATIC
"common/tests/kalman_fitting_telescope_test.hpp"
"common/tests/kalman_fitting_toy_detector_test.hpp"
"common/tests/kalman_fitting_wire_chamber_test.hpp"
"common/tests/kalman_fitting_test.cpp"
"common/tests/space_point.hpp" )
"common/tests/kalman_fitting_test.cpp" )
target_include_directories( traccc_tests_common
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/common )
target_link_libraries( traccc_tests_common
Expand Down
32 changes: 0 additions & 32 deletions tests/common/tests/space_point.hpp

This file was deleted.

159 changes: 47 additions & 112 deletions tests/cpu/compare_with_acts_seeding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

// tests
#include "tests/atlas_cuts.hpp"
#include "tests/space_point.hpp"

// acts
#include "Acts/EventData/Seed.hpp"
Expand All @@ -39,6 +38,7 @@
#include <gtest/gtest.h>

// System include(s).
#include <functional>
#include <limits>

// We need to define a 'SpacePointCollector' that will bridge
Expand All @@ -47,29 +47,28 @@
// need to implement some _impl function to instruct ACTS how to retrieve the
// required quantities

// The internal details of this class are totally up to the user
// The internal details of this class are largely up to the user
class SpacePointCollector {
public:
friend Acts::SpacePointContainer<SpacePointCollector,
Acts::detail::RefHolder>;
using ValueType = SpacePoint;
using ActsSpacePointContainer =
Acts::SpacePointContainer<SpacePointCollector, Acts::detail::RefHolder>;
friend ActsSpacePointContainer;
using ValueType = traccc::spacepoint;

explicit SpacePointCollector(
std::vector<const ValueType*>& externalCollection)
: m_storage(&externalCollection) {}

std::size_t size_impl() const { return storage().size(); }
float x_impl(std::size_t idx) const { return storage()[idx]->x(); }
float y_impl(std::size_t idx) const { return storage()[idx]->y(); }
float z_impl(std::size_t idx) const { return storage()[idx]->z(); }
float varianceR_impl(std::size_t idx) const {
return storage()[idx]->varianceR;
traccc::spacepoint_collection_types::host& spacepoints)
: m_storage(spacepoints) {}

std::size_t size_impl() const { return m_storage.get().size(); }
double x_impl(std::size_t idx) const { return m_storage.get()[idx].x(); }
double y_impl(std::size_t idx) const { return m_storage.get()[idx].y(); }
double z_impl(std::size_t idx) const { return m_storage.get()[idx].z(); }
double varianceR_impl(std::size_t) const { return 0.; }
double varianceZ_impl(std::size_t) const { return 0.; }

const ValueType& get_impl(std::size_t idx) const {
return m_storage.get()[idx];
}
float varianceZ_impl(std::size_t idx) const {
return storage()[idx]->varianceZ;
}

const ValueType& get_impl(std::size_t idx) const { return *storage()[idx]; }

std::any component_impl(Acts::HashedString key, std::size_t) const {
using namespace Acts::HashedStringLiteral;
Expand All @@ -86,42 +85,19 @@ class SpacePointCollector {
}

private:
const std::vector<const ValueType*>& storage() const { return *m_storage; }
std::vector<const ValueType*>& storage() { return *m_storage; }

std::vector<const ValueType*>* m_storage{nullptr};
std::reference_wrapper<traccc::spacepoint_collection_types::host> m_storage;
};

inline bool operator==(const SpacePoint* acts_sp,
const traccc::spacepoint& traccc_sp) {
if (abs(acts_sp->x() - traccc_sp.global[0]) < traccc::float_epsilon &&
abs(acts_sp->y() - traccc_sp.global[1]) < traccc::float_epsilon &&
abs(acts_sp->z() - traccc_sp.global[2]) < traccc::float_epsilon) {
return true;
}
return false;
}

inline bool operator==(const traccc::spacepoint& traccc_sp,
const SpacePoint* acts_sp) {
if (abs(acts_sp->x() - traccc_sp.global[0]) < traccc::float_epsilon &&
abs(acts_sp->y() - traccc_sp.global[1]) < traccc::float_epsilon &&
abs(acts_sp->z() - traccc_sp.global[2]) < traccc::float_epsilon) {
return true;
}
return false;
}

class CompareWithActsSeedingTests
: public ::testing::TestWithParam<
std::tuple<std::string, std::string, unsigned int>> {};

// This defines the local frame test suite
TEST_P(CompareWithActsSeedingTests, Run) {

std::string detector_file = std::get<0>(GetParam());
std::string hits_dir = std::get<1>(GetParam());
unsigned int event = std::get<2>(GetParam());
const std::string detector_file = std::get<0>(GetParam());
const std::string hits_dir = std::get<1>(GetParam());
const unsigned int event = std::get<2>(GetParam());

// Memory resource used by the EDM.
vecmem::host_memory_resource host_mr;
Expand Down Expand Up @@ -150,37 +126,13 @@ TEST_P(CompareWithActsSeedingTests, Run) {
--------------------------------*/

auto internal_spacepoints_per_event = sb(spacepoints_per_event);
auto seeds = sf(spacepoints_per_event, internal_spacepoints_per_event);
auto traccc_seeds =
sf(spacepoints_per_event, internal_spacepoints_per_event);

/*--------------------------------
ACTS seeding
--------------------------------*/

// copy traccc::spacepoint into SpacePoint
std::vector<const SpacePoint*> spVec;
for (auto& sp : spacepoints_per_event) {
SpacePoint* acts_sp =
new SpacePoint{static_cast<float>(sp.global[0]),
static_cast<float>(sp.global[1]),
static_cast<float>(sp.global[2]),
std::hypot(static_cast<float>(sp.global[0]),
static_cast<float>(sp.global[1])),
0,
0,
0};
spVec.push_back(acts_sp);
}

// spacepoint equality check
int n_sp_match = 0;
for (auto& sp : spacepoints_per_event) {
if (std::find(spVec.begin(), spVec.end(), sp) != spVec.end()) {
n_sp_match++;
}
}
EXPECT_EQ(spacepoints_per_event.size(), n_sp_match);
EXPECT_EQ(spVec.size(), n_sp_match);

// We need to do some operations on the space points before we can give them
// to the seeding Config
Acts::SpacePointContainerConfig spConfig;
Expand All @@ -189,14 +141,15 @@ TEST_P(CompareWithActsSeedingTests, Run) {
Acts::SpacePointContainerOptions spOptions;
spOptions.beamPos = {traccc_config.beamPos[0], traccc_config.beamPos[1]};

SpacePointCollector container(spVec);
Acts::SpacePointContainer<decltype(container), Acts::detail::RefHolder>
spContainer(spConfig, spOptions, container);
SpacePointCollector container(spacepoints_per_event);
SpacePointCollector::ActsSpacePointContainer spContainer(
spConfig, spOptions, container);
// The seeding will then iterate on spContainer, that is on space point
// proxies This also means we will create seed of proxies of space points

// Define some types
using spacepoint_t = typename decltype(spContainer)::SpacePointProxyType;
using spacepoint_t =
SpacePointCollector::ActsSpacePointContainer::SpacePointProxyType;
using grid_t = Acts::CylindricalSpacePointGrid<spacepoint_t>;
using binfinder_t = Acts::GridBinFinder<grid_t::DIM>;
using binnedgroup_t = Acts::CylindricalBinnedGroup<spacepoint_t>;
Expand Down Expand Up @@ -390,13 +343,13 @@ TEST_P(CompareWithActsSeedingTests, Run) {
seedfinder_t a(acts_config);

// We define the state and the seed container
std::vector<seed_t> seedVector;
static thread_local decltype(a)::SeedingState state;
std::vector<seed_t> acts_seeds;
seedfinder_t::SeedingState state;
state.spacePointMutableData.resize(spContainer.size());

// Run the seeding
for (const auto [bottom, middle, top] : spGroup) {
a.createSeedsForGroup(acts_options, state, spGroup.grid(), seedVector,
a.createSeedsForGroup(acts_options, state, spGroup.grid(), acts_seeds,
bottom, middle, top, rMiddleSPRange);
}

Expand All @@ -405,48 +358,30 @@ TEST_P(CompareWithActsSeedingTests, Run) {
// the externalSpacePoint() method

// Count the number of matching seeds
// and push_back seed into sorted_seedVector
std::vector<Acts::Seed<SpacePoint>> sorted_seedVector;
int n_seed_match = 0;
for (auto& seed : seeds) {
std::size_t n_matched_acts_seeds = 0u;
for (const auto& traccc_seed : traccc_seeds) {
// Try to find the same Acts seed.
auto it = std::find_if(
seedVector.begin(), seedVector.end(), [&](auto acts_seed) {
auto traccc_spB = spacepoints_per_event.at(seed.spB_link);
auto traccc_spM = spacepoints_per_event.at(seed.spM_link);
auto traccc_spT = spacepoints_per_event.at(seed.spT_link);

auto& triplets = acts_seed.sp();
const SpacePoint* acts_spB = &triplets[0]->externalSpacePoint();
const SpacePoint* acts_spM = &triplets[1]->externalSpacePoint();
const SpacePoint* acts_spT = &triplets[2]->externalSpacePoint();

if (acts_spB == traccc_spB && acts_spM == traccc_spM &&
acts_spT == traccc_spT) {
return true;
}

return false;
acts_seeds.begin(), acts_seeds.end(), [&](const auto& acts_seed) {
return ((traccc_seed.spB_link == acts_seed.sp()[0]->index()) &&
(traccc_seed.spM_link == acts_seed.sp()[1]->index()) &&
(traccc_seed.spT_link == acts_seed.sp()[2]->index()));
});

if (it != seedVector.end()) {
const auto& seed_proxies = *it;
sorted_seedVector.push_back(Acts::Seed<SpacePoint>(
seed_proxies.sp()[0]->externalSpacePoint(),
seed_proxies.sp()[1]->externalSpacePoint(),
seed_proxies.sp()[2]->externalSpacePoint()));
n_seed_match++;
if (it != acts_seeds.end()) {
++n_matched_acts_seeds;
}
}

// Ensure that ACTS and traccc give the same result
// @TODO Uncomment the line below once acts-project/acts#2132 is merged
// EXPECT_EQ(seeds.size(), seedVector.size());
EXPECT_NEAR(static_cast<double>(seeds.size()),
static_cast<double>(sorted_seedVector.size()),
static_cast<double>(seeds.size()) * 0.0023);
EXPECT_GT(
static_cast<double>(n_seed_match) / static_cast<double>(seeds.size()),
0.9977);
EXPECT_NEAR(static_cast<double>(traccc_seeds.size()),
static_cast<double>(n_matched_acts_seeds),
static_cast<double>(traccc_seeds.size()) * 0.0023);
EXPECT_GT(static_cast<double>(n_matched_acts_seeds) /
static_cast<double>(traccc_seeds.size()),
0.9977);
}

INSTANTIATE_TEST_SUITE_P(
Expand Down

0 comments on commit a12397e

Please sign in to comment.