Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Acts Seed-Finding Comparison Update, main branch (2024.11.13.) #771

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TRACCC library, part of the ACTS project (R&D line)
#
# (c) 2021-2023 CERN for the benefit of the ACTS project
# (c) 2021-2024 CERN for the benefit of the ACTS project
#
# Mozilla Public License Version 2.0

Expand All @@ -17,8 +17,7 @@ add_library( traccc_tests_common STATIC
"common/tests/kalman_fitting_telescope_test.hpp"
"common/tests/kalman_fitting_toy_detector_test.hpp"
"common/tests/kalman_fitting_wire_chamber_test.hpp"
"common/tests/kalman_fitting_test.cpp"
"common/tests/space_point.hpp" )
"common/tests/kalman_fitting_test.cpp" )
target_include_directories( traccc_tests_common
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/common )
target_link_libraries( traccc_tests_common
Expand Down
32 changes: 0 additions & 32 deletions tests/common/tests/space_point.hpp

This file was deleted.

159 changes: 47 additions & 112 deletions tests/cpu/compare_with_acts_seeding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

// tests
#include "tests/atlas_cuts.hpp"
#include "tests/space_point.hpp"

// acts
#include "Acts/EventData/Seed.hpp"
Expand All @@ -39,6 +38,7 @@
#include <gtest/gtest.h>

// System include(s).
#include <functional>
#include <limits>

// We need to define a 'SpacePointCollector' that will bridge
Expand All @@ -47,29 +47,28 @@
// need to implement some _impl function to instruct ACTS how to retrieve the
// required quantities

// The internal details of this class are totally up to the user
// The internal details of this class are largely up to the user
class SpacePointCollector {
public:
friend Acts::SpacePointContainer<SpacePointCollector,
Acts::detail::RefHolder>;
using ValueType = SpacePoint;
using ActsSpacePointContainer =
Acts::SpacePointContainer<SpacePointCollector, Acts::detail::RefHolder>;
friend ActsSpacePointContainer;
using ValueType = traccc::spacepoint;

explicit SpacePointCollector(
std::vector<const ValueType*>& externalCollection)
: m_storage(&externalCollection) {}

std::size_t size_impl() const { return storage().size(); }
float x_impl(std::size_t idx) const { return storage()[idx]->x(); }
float y_impl(std::size_t idx) const { return storage()[idx]->y(); }
float z_impl(std::size_t idx) const { return storage()[idx]->z(); }
float varianceR_impl(std::size_t idx) const {
return storage()[idx]->varianceR;
traccc::spacepoint_collection_types::host& spacepoints)
: m_storage(spacepoints) {}

std::size_t size_impl() const { return m_storage.get().size(); }
double x_impl(std::size_t idx) const { return m_storage.get()[idx].x(); }
double y_impl(std::size_t idx) const { return m_storage.get()[idx].y(); }
double z_impl(std::size_t idx) const { return m_storage.get()[idx].z(); }
double varianceR_impl(std::size_t) const { return 0.; }
double varianceZ_impl(std::size_t) const { return 0.; }

const ValueType& get_impl(std::size_t idx) const {
return m_storage.get()[idx];
}
float varianceZ_impl(std::size_t idx) const {
return storage()[idx]->varianceZ;
}

const ValueType& get_impl(std::size_t idx) const { return *storage()[idx]; }

std::any component_impl(Acts::HashedString key, std::size_t) const {
using namespace Acts::HashedStringLiteral;
Expand All @@ -86,42 +85,19 @@ class SpacePointCollector {
}

private:
const std::vector<const ValueType*>& storage() const { return *m_storage; }
std::vector<const ValueType*>& storage() { return *m_storage; }

std::vector<const ValueType*>* m_storage{nullptr};
std::reference_wrapper<traccc::spacepoint_collection_types::host> m_storage;
};

inline bool operator==(const SpacePoint* acts_sp,
const traccc::spacepoint& traccc_sp) {
if (abs(acts_sp->x() - traccc_sp.global[0]) < traccc::float_epsilon &&
abs(acts_sp->y() - traccc_sp.global[1]) < traccc::float_epsilon &&
abs(acts_sp->z() - traccc_sp.global[2]) < traccc::float_epsilon) {
return true;
}
return false;
}

inline bool operator==(const traccc::spacepoint& traccc_sp,
const SpacePoint* acts_sp) {
if (abs(acts_sp->x() - traccc_sp.global[0]) < traccc::float_epsilon &&
abs(acts_sp->y() - traccc_sp.global[1]) < traccc::float_epsilon &&
abs(acts_sp->z() - traccc_sp.global[2]) < traccc::float_epsilon) {
return true;
}
return false;
}

class CompareWithActsSeedingTests
: public ::testing::TestWithParam<
std::tuple<std::string, std::string, unsigned int>> {};

// This defines the local frame test suite
TEST_P(CompareWithActsSeedingTests, Run) {

std::string detector_file = std::get<0>(GetParam());
std::string hits_dir = std::get<1>(GetParam());
unsigned int event = std::get<2>(GetParam());
const std::string detector_file = std::get<0>(GetParam());
const std::string hits_dir = std::get<1>(GetParam());
const unsigned int event = std::get<2>(GetParam());

// Memory resource used by the EDM.
vecmem::host_memory_resource host_mr;
Expand Down Expand Up @@ -150,37 +126,13 @@ TEST_P(CompareWithActsSeedingTests, Run) {
--------------------------------*/

auto internal_spacepoints_per_event = sb(spacepoints_per_event);
auto seeds = sf(spacepoints_per_event, internal_spacepoints_per_event);
auto traccc_seeds =
sf(spacepoints_per_event, internal_spacepoints_per_event);

/*--------------------------------
ACTS seeding
--------------------------------*/

// copy traccc::spacepoint into SpacePoint
std::vector<const SpacePoint*> spVec;
for (auto& sp : spacepoints_per_event) {
SpacePoint* acts_sp =
new SpacePoint{static_cast<float>(sp.global[0]),
static_cast<float>(sp.global[1]),
static_cast<float>(sp.global[2]),
std::hypot(static_cast<float>(sp.global[0]),
static_cast<float>(sp.global[1])),
0,
0,
0};
spVec.push_back(acts_sp);
}

// spacepoint equality check
int n_sp_match = 0;
for (auto& sp : spacepoints_per_event) {
if (std::find(spVec.begin(), spVec.end(), sp) != spVec.end()) {
n_sp_match++;
}
}
EXPECT_EQ(spacepoints_per_event.size(), n_sp_match);
EXPECT_EQ(spVec.size(), n_sp_match);

// We need to do some operations on the space points before we can give them
// to the seeding Config
Acts::SpacePointContainerConfig spConfig;
Expand All @@ -189,14 +141,15 @@ TEST_P(CompareWithActsSeedingTests, Run) {
Acts::SpacePointContainerOptions spOptions;
spOptions.beamPos = {traccc_config.beamPos[0], traccc_config.beamPos[1]};

SpacePointCollector container(spVec);
Acts::SpacePointContainer<decltype(container), Acts::detail::RefHolder>
spContainer(spConfig, spOptions, container);
SpacePointCollector container(spacepoints_per_event);
SpacePointCollector::ActsSpacePointContainer spContainer(
spConfig, spOptions, container);
// The seeding will then iterate on spContainer, that is on space point
// proxies This also means we will create seed of proxies of space points

// Define some types
using spacepoint_t = typename decltype(spContainer)::SpacePointProxyType;
using spacepoint_t =
SpacePointCollector::ActsSpacePointContainer::SpacePointProxyType;
using grid_t = Acts::CylindricalSpacePointGrid<spacepoint_t>;
using binfinder_t = Acts::GridBinFinder<grid_t::DIM>;
using binnedgroup_t = Acts::CylindricalBinnedGroup<spacepoint_t>;
Expand Down Expand Up @@ -390,13 +343,13 @@ TEST_P(CompareWithActsSeedingTests, Run) {
seedfinder_t a(acts_config);

// We define the state and the seed container
std::vector<seed_t> seedVector;
static thread_local decltype(a)::SeedingState state;
std::vector<seed_t> acts_seeds;
seedfinder_t::SeedingState state;
state.spacePointMutableData.resize(spContainer.size());

// Run the seeding
for (const auto [bottom, middle, top] : spGroup) {
a.createSeedsForGroup(acts_options, state, spGroup.grid(), seedVector,
a.createSeedsForGroup(acts_options, state, spGroup.grid(), acts_seeds,
bottom, middle, top, rMiddleSPRange);
}

Expand All @@ -405,48 +358,30 @@ TEST_P(CompareWithActsSeedingTests, Run) {
// the externalSpacePoint() method

// Count the number of matching seeds
// and push_back seed into sorted_seedVector
std::vector<Acts::Seed<SpacePoint>> sorted_seedVector;
int n_seed_match = 0;
for (auto& seed : seeds) {
std::size_t n_matched_acts_seeds = 0u;
for (const auto& traccc_seed : traccc_seeds) {
// Try to find the same Acts seed.
auto it = std::find_if(
seedVector.begin(), seedVector.end(), [&](auto acts_seed) {
auto traccc_spB = spacepoints_per_event.at(seed.spB_link);
auto traccc_spM = spacepoints_per_event.at(seed.spM_link);
auto traccc_spT = spacepoints_per_event.at(seed.spT_link);

auto& triplets = acts_seed.sp();
const SpacePoint* acts_spB = &triplets[0]->externalSpacePoint();
const SpacePoint* acts_spM = &triplets[1]->externalSpacePoint();
const SpacePoint* acts_spT = &triplets[2]->externalSpacePoint();

if (acts_spB == traccc_spB && acts_spM == traccc_spM &&
acts_spT == traccc_spT) {
return true;
}

return false;
acts_seeds.begin(), acts_seeds.end(), [&](const auto& acts_seed) {
return ((traccc_seed.spB_link == acts_seed.sp()[0]->index()) &&
(traccc_seed.spM_link == acts_seed.sp()[1]->index()) &&
(traccc_seed.spT_link == acts_seed.sp()[2]->index()));
});

if (it != seedVector.end()) {
const auto& seed_proxies = *it;
sorted_seedVector.push_back(Acts::Seed<SpacePoint>(
seed_proxies.sp()[0]->externalSpacePoint(),
seed_proxies.sp()[1]->externalSpacePoint(),
seed_proxies.sp()[2]->externalSpacePoint()));
n_seed_match++;
if (it != acts_seeds.end()) {
++n_matched_acts_seeds;
}
}

// Ensure that ACTS and traccc give the same result
// @TODO Uncomment the line below once acts-project/acts#2132 is merged
// EXPECT_EQ(seeds.size(), seedVector.size());
EXPECT_NEAR(static_cast<double>(seeds.size()),
static_cast<double>(sorted_seedVector.size()),
static_cast<double>(seeds.size()) * 0.0023);
EXPECT_GT(
static_cast<double>(n_seed_match) / static_cast<double>(seeds.size()),
0.9977);
EXPECT_NEAR(static_cast<double>(traccc_seeds.size()),
static_cast<double>(n_matched_acts_seeds),
static_cast<double>(traccc_seeds.size()) * 0.0023);
EXPECT_GT(static_cast<double>(n_matched_acts_seeds) /
static_cast<double>(traccc_seeds.size()),
0.9977);
}

INSTANTIATE_TEST_SUITE_P(
Expand Down
Loading