From 1b5fcf6776486f19834ee9709b34e8079884c371 Mon Sep 17 00:00:00 2001 From: Juan Miguel Carceller <22276694+jmcarcell@users.noreply.github.com> Date: Tue, 11 Jul 2023 12:18:44 +0200 Subject: [PATCH] Add support for the new RNTuple format (#395) * Add a RNTuple writer * Cleanup and add a reader * Add compilation instructions for RNTuple * Add tests * Fix the reader and writer so that they pass most of the tests * Commit missing changes in the header * Add support for Generic Parameters * Add an ugly workaround to the unique_ptr issue * Read also vector members and remove some comments * Do a bit of cleanup * Do more cleanup, also compiler warnings * Add names in rootUtils.h, fix a few compiler warnings * Add a few minor changes * Add missing changes in the headers * Change map -> unordered_map and use append in CMakeLists.txt * Simplify writing and reading of generic parameters * Only create the ID table once * Add CollectionInfo structs * Add a ROOT version check * Add missing endif() * Add Name at the end of some names * Add missing Name at the end * Cast to rvalue * Cache entries and reserve * Add comment and remove old comments * Remove a few couts * Remove intermediate variables and use std::move * Run clang-format * Use clang-format on tests too * Enable RNTuple I/O in Key4hep CI * Check if dev3 workflows come with recent enough ROOT * Change MakeField to the new signature * Update the RNTuple reader and writer to use the buffer factory * Run clang-format * Update the RNTuple writer to use a bare model * Add friends for Generic Parameters * Update changes after the changes in the collectionID and string_view * Run clang-format * Update the reader and writer to conform to #405 * Reorganize and clean up code in the reader * Run clang-format * Simplify how the references are filled --------- Co-authored-by: jmcarcell Co-authored-by: tmadlener --- .github/workflows/key4hep.yml | 3 +- .github/workflows/ubuntu.yml | 7 +- CMakeLists.txt | 10 +- include/podio/CollectionBuffers.h | 1 + include/podio/GenericParameters.h | 8 + include/podio/ROOTNTupleReader.h | 104 +++++++++ include/podio/ROOTNTupleWriter.h | 74 ++++++ include/podio/UserDataCollection.h | 2 +- python/templates/CollectionData.cc.jinja2 | 1 + src/CMakeLists.txt | 18 ++ src/ROOTNTupleReader.cc | 186 +++++++++++++++ src/ROOTNTupleWriter.cc | 263 ++++++++++++++++++++++ src/rootUtils.h | 69 ++++++ tests/root_io/CMakeLists.txt | 10 + tests/root_io/read_rntuple.cpp | 6 + tests/root_io/write_rntuple.cpp | 6 + 16 files changed, 762 insertions(+), 6 deletions(-) create mode 100644 include/podio/ROOTNTupleReader.h create mode 100644 include/podio/ROOTNTupleWriter.h create mode 100644 src/ROOTNTupleReader.cc create mode 100644 src/ROOTNTupleWriter.cc create mode 100644 tests/root_io/read_rntuple.cpp create mode 100644 tests/root_io/write_rntuple.cpp diff --git a/.github/workflows/key4hep.yml b/.github/workflows/key4hep.yml index 92240b93e..43f69ef67 100644 --- a/.github/workflows/key4hep.yml +++ b/.github/workflows/key4hep.yml @@ -30,7 +30,8 @@ jobs: -DCMAKE_INSTALL_PREFIX=../install \ -DCMAKE_CXX_STANDARD=17 \ -DCMAKE_CXX_FLAGS=" -fdiagnostics-color=always -Werror -Wno-error=deprecated-declarations " \ - -DUSE_EXTERNAL_CATCH2=AUTO \ + -DUSE_EXTERNAL_CATCH2=ON \ + -DENABLE_RNTUPLE=ON \ -G Ninja .. echo "::endgroup::" echo "::group::Build" diff --git a/.github/workflows/ubuntu.yml b/.github/workflows/ubuntu.yml index 16ae05bee..9814ca7cb 100644 --- a/.github/workflows/ubuntu.yml +++ b/.github/workflows/ubuntu.yml @@ -30,9 +30,10 @@ jobs: -DCMAKE_INSTALL_PREFIX=../install \ -DCMAKE_CXX_STANDARD=17 \ -DCMAKE_CXX_FLAGS=" -fdiagnostics-color=always -Werror -Wno-error=deprecated-declarations " \ - -DUSE_EXTERNAL_CATCH2=OFF \ - -DPODIO_SET_RPATH=ON \ - -G Ninja .. + -DUSE_EXTERNAL_CATCH2=OFF \ + -DPODIO_SET_RPATH=ON \ + -DENABLE_RNTUPLE=ON \ + -G Ninja .. echo "::endgroup::" echo "::group::Build" ninja -k0 diff --git a/CMakeLists.txt b/CMakeLists.txt index de56ea75a..35c521882 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,12 +68,20 @@ ADD_CLANG_TIDY() option(CREATE_DOC "Whether or not to create doxygen doc target." OFF) option(ENABLE_SIO "Build SIO I/O support" OFF) option(PODIO_RELAX_PYVER "Do not require exact python version match with ROOT" OFF) +option(ENABLE_RNTUPLE "Build with support for the new ROOT NTtuple format" OFF) #--- Declare ROOT dependency --------------------------------------------------- list(APPEND CMAKE_PREFIX_PATH $ENV{ROOTSYS}) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) -find_package(ROOT REQUIRED COMPONENTS RIO Tree) +if(NOT ENABLE_RNTUPLE) + find_package(ROOT REQUIRED COMPONENTS RIO Tree) +else() + find_package(ROOT REQUIRED COMPONENTS RIO Tree ROOTNTuple) + if(${ROOT_VERSION} VERSION_LESS 6.28.02) + message(FATAL_ERROR "You are trying to build podio with support for the new ROOT NTuple format, but your ROOT version is too old. Please update ROOT to at least version 6.28.02") + endif() +endif() # Check that root is compiled with a modern enough c++ standard get_target_property(ROOT_COMPILE_FEATURES ROOT::Core INTERFACE_COMPILE_FEATURES) diff --git a/include/podio/CollectionBuffers.h b/include/podio/CollectionBuffers.h index 37ee07fe4..b51161c2d 100644 --- a/include/podio/CollectionBuffers.h +++ b/include/podio/CollectionBuffers.h @@ -27,6 +27,7 @@ using VectorMembersInfo = std::vector>; */ struct CollectionWriteBuffers { void* data{nullptr}; + void* vecPtr{nullptr}; CollRefCollection* references{nullptr}; VectorMembersInfo* vectorMembers{nullptr}; diff --git a/include/podio/GenericParameters.h b/include/podio/GenericParameters.h index 83eaef8f7..a5d7dde76 100644 --- a/include/podio/GenericParameters.h +++ b/include/podio/GenericParameters.h @@ -18,6 +18,11 @@ class write_device; using version_type = uint32_t; // from sio/definitions } // namespace sio +namespace podio { +class ROOTNTupleReader; +class ROOTNTupleWriter; +} // namespace podio + #define DEPR_NON_TEMPLATE \ [[deprecated("Non-templated access will be removed. Switch to templated access functionality")]] @@ -145,6 +150,8 @@ class GenericParameters { friend void writeGenericParameters(sio::write_device& device, const GenericParameters& parameters); friend void readGenericParameters(sio::read_device& device, GenericParameters& parameters, sio::version_type version); + friend ROOTNTupleReader; + friend ROOTNTupleWriter; /// Get a reference to the internal map for a given type template @@ -187,6 +194,7 @@ class GenericParameters { } } +private: /// Get the mutex that guards the map for the given type template std::mutex& getMutex() const { diff --git a/include/podio/ROOTNTupleReader.h b/include/podio/ROOTNTupleReader.h new file mode 100644 index 000000000..a25f66f2a --- /dev/null +++ b/include/podio/ROOTNTupleReader.h @@ -0,0 +1,104 @@ +#ifndef PODIO_ROOTNTUPLEREADER_H +#define PODIO_ROOTNTUPLEREADER_H + +#include "podio/CollectionBranches.h" +#include "podio/ICollectionProvider.h" +#include "podio/ROOTFrameData.h" +#include "podio/SchemaEvolution.h" +#include "podio/podioVersion.h" +#include "podio/utilities/DatamodelRegistryIOHelpers.h" + +#include +#include +#include + +#include +#include + +namespace podio { + +/** +This class has the function to read available data from disk +and to prepare collections and buffers. +**/ +class ROOTNTupleReader { + +public: + ROOTNTupleReader() = default; + ~ROOTNTupleReader() = default; + + ROOTNTupleReader(const ROOTNTupleReader&) = delete; + ROOTNTupleReader& operator=(const ROOTNTupleReader&) = delete; + + void openFile(const std::string& filename); + void openFiles(const std::vector& filename); + + /** + * Read the next data entry from which a Frame can be constructed for the + * given name. In case there are no more entries left for this name or in + * case there is no data for this name, this returns a nullptr. + */ + std::unique_ptr readNextEntry(const std::string& name); + + /** + * Read the specified data entry from which a Frame can be constructed for + * the given name. In case the entry does not exist for this name or in case + * there is no data for this name, this returns a nullptr. + */ + std::unique_ptr readEntry(const std::string& name, const unsigned entry); + + /// Returns number of entries for the given name + unsigned getEntries(const std::string& name); + + /// Get the build version of podio that has been used to write the current file + podio::version::Version currentFileVersion() const { + return m_fileVersion; + } + + void closeFile(); + +private: + /** + * Initialize the given category by filling the maps with metadata information + * that will be used later + */ + bool initCategory(const std::string& category); + + /** + * Read and reconstruct the generic parameters of the Frame + */ + GenericParameters readEventMetaData(const std::string& name, unsigned entNum); + + template + void readParams(const std::string& name, unsigned entNum, GenericParameters& params); + + std::unique_ptr m_metadata{}; + + podio::version::Version m_fileVersion{}; + DatamodelDefinitionHolder m_datamodelHolder{}; + + std::unordered_map>> m_readers{}; + std::unordered_map> m_metadata_readers{}; + std::vector m_filenames{}; + + std::unordered_map m_entries{}; + std::unordered_map m_totalEntries{}; + + struct CollectionInfo { + std::vector id{}; + std::vector name{}; + std::vector type{}; + std::vector isSubsetCollection{}; + std::vector schemaVersion{}; + }; + + std::unordered_map m_collectionInfo{}; + + std::vector m_availableCategories{}; + + std::shared_ptr m_table{}; +}; + +} // namespace podio + +#endif diff --git a/include/podio/ROOTNTupleWriter.h b/include/podio/ROOTNTupleWriter.h new file mode 100644 index 000000000..0f6f6d466 --- /dev/null +++ b/include/podio/ROOTNTupleWriter.h @@ -0,0 +1,74 @@ +#ifndef PODIO_ROOTNTUPLEWRITER_H +#define PODIO_ROOTNTUPLEWRITER_H + +#include "podio/CollectionBase.h" +#include "podio/Frame.h" +#include "podio/GenericParameters.h" +#include "podio/SchemaEvolution.h" +#include "podio/utilities/DatamodelRegistryIOHelpers.h" + +#include "TFile.h" +#include +#include + +#include +#include +#include + +namespace podio { + +class ROOTNTupleWriter { +public: + ROOTNTupleWriter(const std::string& filename); + ~ROOTNTupleWriter(); + + ROOTNTupleWriter(const ROOTNTupleWriter&) = delete; + ROOTNTupleWriter& operator=(const ROOTNTupleWriter&) = delete; + + template + void fillParams(GenericParameters& params, ROOT::Experimental::REntry* entry); + + void writeFrame(const podio::Frame& frame, const std::string& category); + void writeFrame(const podio::Frame& frame, const std::string& category, const std::vector& collsToWrite); + void finish(); + +private: + using StoreCollection = std::pair; + std::unique_ptr createModels(const std::vector& collections); + + std::unique_ptr m_metadata{}; + std::unordered_map> m_writers{}; + std::unique_ptr m_metadataWriter{}; + + std::unique_ptr m_file{}; + + DatamodelDefinitionCollector m_datamodelCollector{}; + + struct CollectionInfo { + std::vector id{}; + std::vector name{}; + std::vector type{}; + std::vector isSubsetCollection{}; + std::vector schemaVersion{}; + }; + + std::unordered_map m_collectionInfo{}; + + std::set m_categories{}; + + bool m_finished{false}; + + std::vector m_intkeys{}, m_floatkeys{}, m_doublekeys{}, m_stringkeys{}; + + std::vector> m_intvalues{}; + std::vector> m_floatvalues{}; + std::vector> m_doublevalues{}; + std::vector> m_stringvalues{}; + + template + std::pair&, std::vector>&> getKeyValueVectors(); +}; + +} // namespace podio + +#endif // PODIO_ROOTNTUPLEWRITER_H diff --git a/include/podio/UserDataCollection.h b/include/podio/UserDataCollection.h index b9aefaf40..cc5b7154f 100644 --- a/include/podio/UserDataCollection.h +++ b/include/podio/UserDataCollection.h @@ -123,7 +123,7 @@ class UserDataCollection : public CollectionBase { /// Get the collection buffers for this collection podio::CollectionWriteBuffers getBuffers() override { _vecPtr = &_vec; // Set the pointer to the correct internal vector - return {&_vecPtr, &m_refCollections, &m_vecmem_info}; + return {&_vecPtr, _vecPtr, &m_refCollections, &m_vecmem_info}; } /// check for validity of the container after read diff --git a/python/templates/CollectionData.cc.jinja2 b/python/templates/CollectionData.cc.jinja2 index 3ae5d3a80..3946ad756 100644 --- a/python/templates/CollectionData.cc.jinja2 +++ b/python/templates/CollectionData.cc.jinja2 @@ -92,6 +92,7 @@ podio::CollectionWriteBuffers {{ class_type }}::getCollectionBuffers(bool isSubs return { isSubsetColl ? nullptr : (void*)&m_data, + isSubsetColl ? nullptr : (void*)m_data.get(), &m_refCollections, // only need to store the ObjectIDs of the referenced objects &m_vecmem_info }; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 394f23057..daea12e5f 100755 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -34,6 +34,9 @@ FUNCTION(PODIO_ADD_LIB_AND_DICT libname headers sources selection ) $ $) target_link_libraries(${dictname} PUBLIC podio::${libname} podio::podio ROOT::Core ROOT::Tree) + if(ENABLE_RNTUPLE) + target_link_libraries(${dictname} PUBLIC ROOT::ROOTNTuple) + endif() PODIO_GENERATE_DICTIONARY(${dictname} ${headers} SELECTION ${selection} OPTIONS --library ${CMAKE_SHARED_LIBRARY_PREFIX}${dictname}${CMAKE_SHARED_LIBRARY_SUFFIX} ) @@ -83,15 +86,30 @@ SET(root_sources ROOTFrameReader.cc ROOTLegacyReader.cc ) +if(ENABLE_RNTUPLE) + list(APPEND root_sources + ROOTNTupleReader.cc + ROOTNTupleWriter.cc + ) +endif() SET(root_headers ${CMAKE_SOURCE_DIR}/include/podio/ROOTFrameReader.h ${CMAKE_SOURCE_DIR}/include/podio/ROOTLegacyReader.h ${CMAKE_SOURCE_DIR}/include/podio/ROOTFrameWriter.h ) +if(ENABLE_RNTUPLE) + list(APPEND root_headers + ${CMAKE_SOURCE_DIR}/include/podio/ROOTNTupleReader.h + ${CMAKE_SOURCE_DIR}/include/podio/ROOTNTupleWriter.h + ) +endif() PODIO_ADD_LIB_AND_DICT(podioRootIO "${root_headers}" "${root_sources}" root_selection.xml) target_link_libraries(podioRootIO PUBLIC podio::podio ROOT::Core ROOT::RIO ROOT::Tree) +if(ENABLE_RNTUPLE) + target_link_libraries(podioRootIO PUBLIC ROOT::ROOTNTuple) +endif() # --- Python EventStore for enabling (legacy) python bindings diff --git a/src/ROOTNTupleReader.cc b/src/ROOTNTupleReader.cc new file mode 100644 index 000000000..299f88da3 --- /dev/null +++ b/src/ROOTNTupleReader.cc @@ -0,0 +1,186 @@ +#include "podio/ROOTNTupleReader.h" +#include "podio/CollectionBase.h" +#include "podio/CollectionBufferFactory.h" +#include "podio/CollectionBuffers.h" +#include "podio/CollectionIDTable.h" +#include "podio/DatamodelRegistry.h" +#include "podio/GenericParameters.h" +#include "rootUtils.h" + +#include "TClass.h" +#include +#include + +namespace podio { + +template +void ROOTNTupleReader::readParams(const std::string& name, unsigned entNum, GenericParameters& params) { + auto keyView = m_readers[name][0]->GetView>(root_utils::getGPKeyName()); + auto valueView = m_readers[name][0]->GetView>>(root_utils::getGPValueName()); + + for (size_t i = 0; i < keyView(entNum).size(); ++i) { + params.getMap().emplace(std::move(keyView(entNum)[i]), std::move(valueView(entNum)[i])); + } +} + +GenericParameters ROOTNTupleReader::readEventMetaData(const std::string& name, unsigned entNum) { + GenericParameters params; + + readParams(name, entNum, params); + readParams(name, entNum, params); + readParams(name, entNum, params); + readParams(name, entNum, params); + + return params; +} + +bool ROOTNTupleReader::initCategory(const std::string& category) { + if (std::find(m_availableCategories.begin(), m_availableCategories.end(), category) == m_availableCategories.end()) { + return false; + } + // Assume that the metadata is the same in all files + auto filename = m_filenames[0]; + + auto id = m_metadata_readers[filename]->GetView>(root_utils::idTableName(category)); + m_collectionInfo[category].id = id(0); + + auto collectionName = + m_metadata_readers[filename]->GetView>(root_utils::collectionName(category)); + m_collectionInfo[category].name = collectionName(0); + + auto collectionType = + m_metadata_readers[filename]->GetView>(root_utils::collInfoName(category)); + m_collectionInfo[category].type = collectionType(0); + + auto subsetCollection = + m_metadata_readers[filename]->GetView>(root_utils::subsetCollection(category)); + m_collectionInfo[category].isSubsetCollection = subsetCollection(0); + + auto schemaVersion = m_metadata_readers[filename]->GetView>("schemaVersion_" + category); + m_collectionInfo[category].schemaVersion = schemaVersion(0); + + return true; +} + +void ROOTNTupleReader::openFile(const std::string& filename) { + openFiles({filename}); +} + +void ROOTNTupleReader::openFiles(const std::vector& filenames) { + + m_filenames.insert(m_filenames.end(), filenames.begin(), filenames.end()); + for (auto& filename : filenames) { + if (m_metadata_readers.find(filename) == m_metadata_readers.end()) { + m_metadata_readers[filename] = ROOT::Experimental::RNTupleReader::Open(root_utils::metaTreeName, filename); + } + } + + m_metadata = ROOT::Experimental::RNTupleReader::Open(root_utils::metaTreeName, filenames[0]); + + auto versionView = m_metadata->GetView>(root_utils::versionBranchName); + auto version = versionView(0); + + m_fileVersion = podio::version::Version{version[0], version[1], version[2]}; + + auto edmView = m_metadata->GetView>>(root_utils::edmDefBranchName); + auto edm = edmView(0); + + auto availableCategoriesField = m_metadata->GetView>(root_utils::availableCategories); + m_availableCategories = availableCategoriesField(0); +} + +unsigned ROOTNTupleReader::getEntries(const std::string& name) { + if (m_readers.find(name) == m_readers.end()) { + for (auto& filename : m_filenames) { + try { + m_readers[name].emplace_back(ROOT::Experimental::RNTupleReader::Open(name, filename)); + } catch (const ROOT::Experimental::RException& e) { + std::cout << "Category " << name << " not found in file " << filename << std::endl; + } + } + m_totalEntries[name] = std::accumulate(m_readers[name].begin(), m_readers[name].end(), 0, + [](int total, auto& reader) { return total + reader->GetNEntries(); }); + } + return m_totalEntries[name]; +} + +std::unique_ptr ROOTNTupleReader::readNextEntry(const std::string& name) { + return readEntry(name, m_entries[name]); +} + +std::unique_ptr ROOTNTupleReader::readEntry(const std::string& category, const unsigned entNum) { + if (m_totalEntries.find(category) == m_totalEntries.end()) { + getEntries(category); + } + if (entNum >= m_totalEntries[category]) { + return nullptr; + } + + if (m_collectionInfo.find(category) == m_collectionInfo.end()) { + if (!initCategory(category)) { + return nullptr; + } + } + + m_entries[category] = entNum + 1; + + ROOTFrameData::BufferMap buffers; + auto dentry = m_readers[category][0]->GetModel()->GetDefaultEntry(); + + for (size_t i = 0; i < m_collectionInfo[category].id.size(); ++i) { + const auto collectionClass = TClass::GetClass(m_collectionInfo[category].type[i].c_str()); + + auto collection = + std::unique_ptr(static_cast(collectionClass->New())); + + const auto& bufferFactory = podio::CollectionBufferFactory::instance(); + auto maybeBuffers = + bufferFactory.createBuffers(m_collectionInfo[category].type[i], m_collectionInfo[category].schemaVersion[i], + m_collectionInfo[category].isSubsetCollection[i]); + auto collBuffers = maybeBuffers.value_or(podio::CollectionReadBuffers{}); + + if (!maybeBuffers) { + std::cout << "WARNING: Buffers couldn't be created for collection " << m_collectionInfo[category].name[i] + << " of type " << m_collectionInfo[category].type[i] << " and schema version " + << m_collectionInfo[category].schemaVersion[i] << std::endl; + return nullptr; + } + + if (m_collectionInfo[category].isSubsetCollection[i]) { + auto brName = root_utils::subsetBranch(m_collectionInfo[category].name[i]); + auto vec = new std::vector; + dentry->CaptureValueUnsafe(brName, vec); + collBuffers.references->at(0) = std::unique_ptr>(vec); + } else { + dentry->CaptureValueUnsafe(m_collectionInfo[category].name[i], collBuffers.data); + + const auto relVecNames = podio::DatamodelRegistry::instance().getRelationNames(collection->getTypeName()); + for (size_t j = 0; j < relVecNames.relations.size(); ++j) { + const auto relName = relVecNames.relations[j]; + auto vec = new std::vector; + const auto brName = root_utils::refBranch(m_collectionInfo[category].name[i], relName); + dentry->CaptureValueUnsafe(brName, vec); + collBuffers.references->at(j) = std::unique_ptr>(vec); + } + + for (size_t j = 0; j < relVecNames.vectorMembers.size(); ++j) { + const auto vecName = relVecNames.vectorMembers[j]; + const auto brName = root_utils::vecBranch(m_collectionInfo[category].name[i], vecName); + dentry->CaptureValueUnsafe(brName, collBuffers.vectorMembers->at(j).second); + } + } + + buffers.emplace(m_collectionInfo[category].name[i], std::move(collBuffers)); + } + + m_readers[category][0]->LoadEntry(entNum); + + auto parameters = readEventMetaData(category, entNum); + if (!m_table) { + m_table = std::make_shared(m_collectionInfo[category].id, m_collectionInfo[category].name); + } + + return std::make_unique(std::move(buffers), m_table, std::move(parameters)); +} + +} // namespace podio diff --git a/src/ROOTNTupleWriter.cc b/src/ROOTNTupleWriter.cc new file mode 100644 index 000000000..741af53b3 --- /dev/null +++ b/src/ROOTNTupleWriter.cc @@ -0,0 +1,263 @@ +#include "podio/ROOTNTupleWriter.h" +#include "podio/CollectionBase.h" +#include "podio/DatamodelRegistry.h" +#include "podio/GenericParameters.h" +#include "podio/SchemaEvolution.h" +#include "podio/podioVersion.h" +#include "rootUtils.h" + +#include "TFile.h" +#include +#include +#include + +#include + +namespace podio { + +ROOTNTupleWriter::ROOTNTupleWriter(const std::string& filename) : + m_metadata(ROOT::Experimental::RNTupleModel::Create()), + m_file(new TFile(filename.c_str(), "RECREATE", "data file")) { +} + +ROOTNTupleWriter::~ROOTNTupleWriter() { + if (!m_finished) { + finish(); + } +} + +template +std::pair&, std::vector>&> ROOTNTupleWriter::getKeyValueVectors() { + if constexpr (std::is_same_v) { + return {m_intkeys, m_intvalues}; + } else if constexpr (std::is_same_v) { + return {m_floatkeys, m_floatvalues}; + } else if constexpr (std::is_same_v) { + return {m_doublekeys, m_doublevalues}; + } else if constexpr (std::is_same_v) { + return {m_stringkeys, m_stringvalues}; + } else { + throw std::runtime_error("Unknown type"); + } +} + +template +void ROOTNTupleWriter::fillParams(GenericParameters& params, ROOT::Experimental::REntry* entry) { + auto [key, value] = getKeyValueVectors(); + entry->CaptureValueUnsafe(root_utils::getGPKeyName(), &key); + entry->CaptureValueUnsafe(root_utils::getGPValueName(), &value); + + key.clear(); + key.reserve(params.getMap().size()); + value.clear(); + value.reserve(params.getMap().size()); + + for (auto& [kk, vv] : params.getMap()) { + key.emplace_back(kk); + value.emplace_back(vv); + } +} + +void ROOTNTupleWriter::writeFrame(const podio::Frame& frame, const std::string& category) { + writeFrame(frame, category, frame.getAvailableCollections()); +} + +void ROOTNTupleWriter::writeFrame(const podio::Frame& frame, const std::string& category, + const std::vector& collsToWrite) { + + std::vector collections; + collections.reserve(collsToWrite.size()); + for (const auto& name : collsToWrite) { + auto* coll = frame.getCollectionForWrite(name); + collections.emplace_back(name, const_cast(coll)); + } + + bool new_category = false; + if (m_writers.find(category) == m_writers.end()) { + new_category = true; + auto model = createModels(collections); + m_writers[category] = ROOT::Experimental::RNTupleWriter::Append(std::move(model), category, *m_file.get(), {}); + } + + auto entry = m_writers[category]->GetModel()->CreateBareEntry(); + + ROOT::Experimental::RNTupleWriteOptions options; + options.SetCompression(ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose); + + for (const auto& [name, coll] : collections) { + auto collBuffers = coll->getBuffers(); + if (collBuffers.vecPtr) { + entry->CaptureValueUnsafe(name, (void*)collBuffers.vecPtr); + } + + if (coll->isSubsetCollection()) { + auto& refColl = (*collBuffers.references)[0]; + const auto brName = root_utils::subsetBranch(name); + entry->CaptureValueUnsafe(brName, refColl.get()); + } else { + + const auto relVecNames = podio::DatamodelRegistry::instance().getRelationNames(coll->getValueTypeName()); + if (auto refColls = collBuffers.references) { + int i = 0; + for (auto& c : (*refColls)) { + const auto brName = root_utils::refBranch(name, relVecNames.relations[i]); + entry->CaptureValueUnsafe(brName, c.get()); + ++i; + } + } + + if (auto vmInfo = collBuffers.vectorMembers) { + int i = 0; + for (auto& [type, vec] : (*vmInfo)) { + const auto typeName = "vector<" + type + ">"; + const auto brName = root_utils::vecBranch(name, relVecNames.vectorMembers[i]); + auto ptr = *(std::vector**)vec; + entry->CaptureValueUnsafe(brName, ptr); + ++i; + } + } + } + + // Not supported + // entry->CaptureValueUnsafe(root_utils::paramBranchName, + // &const_cast(frame.getParameters())); + + if (new_category) { + m_collectionInfo[category].id.emplace_back(coll->getID()); + m_collectionInfo[category].name.emplace_back(name); + m_collectionInfo[category].type.emplace_back(coll->getTypeName()); + m_collectionInfo[category].isSubsetCollection.emplace_back(coll->isSubsetCollection()); + m_collectionInfo[category].schemaVersion.emplace_back(coll->getSchemaVersion()); + } + } + + auto params = frame.getParameters(); + fillParams(params, entry.get()); + fillParams(params, entry.get()); + fillParams(params, entry.get()); + fillParams(params, entry.get()); + + m_writers[category]->Fill(*entry); + m_categories.insert(category); +} + +std::unique_ptr +ROOTNTupleWriter::createModels(const std::vector& collections) { + auto model = ROOT::Experimental::RNTupleModel::CreateBare(); + for (auto& [name, coll] : collections) { + const auto collBuffers = coll->getBuffers(); + + if (collBuffers.vecPtr) { + auto collClassName = "std::vector<" + std::string(coll->getDataTypeName()) + ">"; + auto field = ROOT::Experimental::Detail::RFieldBase::Create(name, collClassName).Unwrap(); + model->AddField(std::move(field)); + } + + if (coll->isSubsetCollection()) { + const auto brName = root_utils::subsetBranch(name); + auto collClassName = "vector"; + auto field = ROOT::Experimental::Detail::RFieldBase::Create(brName, collClassName).Unwrap(); + model->AddField(std::move(field)); + } else { + + const auto relVecNames = podio::DatamodelRegistry::instance().getRelationNames(coll->getValueTypeName()); + if (auto refColls = collBuffers.references) { + int i = 0; + for (auto& c [[maybe_unused]] : (*refColls)) { + const auto brName = root_utils::refBranch(name, relVecNames.relations[i]); + auto collClassName = "vector"; + auto field = ROOT::Experimental::Detail::RFieldBase::Create(brName, collClassName).Unwrap(); + model->AddField(std::move(field)); + ++i; + } + } + + if (auto vminfo = collBuffers.vectorMembers) { + int i = 0; + for (auto& [type, vec] : (*vminfo)) { + const auto typeName = "vector<" + type + ">"; + const auto brName = root_utils::vecBranch(name, relVecNames.vectorMembers[i]); + auto field = ROOT::Experimental::Detail::RFieldBase::Create(brName, typeName).Unwrap(); + model->AddField(std::move(field)); + ++i; + } + } + } + } + + // Not supported by ROOT because podio::GenericParameters has map types + // so we have to split them manually + // model->MakeField(root_utils::paramBranchName); + + model->AddField( + ROOT::Experimental::Detail::RFieldBase::Create(root_utils::intKeyName, "std::vector>").Unwrap()); + model->AddField( + ROOT::Experimental::Detail::RFieldBase::Create(root_utils::floatKeyName, "std::vector>").Unwrap()); + model->AddField( + ROOT::Experimental::Detail::RFieldBase::Create(root_utils::doubleKeyName, "std::vector>").Unwrap()); + model->AddField( + ROOT::Experimental::Detail::RFieldBase::Create(root_utils::stringKeyName, "std::vector>").Unwrap()); + + model->AddField( + ROOT::Experimental::Detail::RFieldBase::Create(root_utils::intValueName, "std::vector>") + .Unwrap()); + model->AddField( + ROOT::Experimental::Detail::RFieldBase::Create(root_utils::floatValueName, "std::vector>") + .Unwrap()); + model->AddField( + ROOT::Experimental::Detail::RFieldBase::Create(root_utils::doubleValueName, "std::vector>") + .Unwrap()); + model->AddField(ROOT::Experimental::Detail::RFieldBase::Create(root_utils::stringValueName, + "std::vector>") + .Unwrap()); + + model->Freeze(); + return model; +} + +void ROOTNTupleWriter::finish() { + + auto podioVersion = podio::version::build_version; + auto versionField = m_metadata->MakeField>(root_utils::versionBranchName); + *versionField = {podioVersion.major, podioVersion.minor, podioVersion.patch}; + + auto edmDefinitions = m_datamodelCollector.getDatamodelDefinitionsToWrite(); + auto edmField = + m_metadata->MakeField>>(root_utils::edmDefBranchName); + *edmField = edmDefinitions; + + auto availableCategoriesField = m_metadata->MakeField>(root_utils::availableCategories); + for (auto& [c, _] : m_collectionInfo) { + availableCategoriesField->push_back(c); + } + + for (auto& category : m_categories) { + auto idField = m_metadata->MakeField>({root_utils::idTableName(category)}); + *idField = m_collectionInfo[category].id; + auto collectionNameField = m_metadata->MakeField>({root_utils::collectionName(category)}); + *collectionNameField = m_collectionInfo[category].name; + auto collectionTypeField = m_metadata->MakeField>({root_utils::collInfoName(category)}); + *collectionTypeField = m_collectionInfo[category].type; + auto subsetCollectionField = m_metadata->MakeField>({root_utils::subsetCollection(category)}); + *subsetCollectionField = m_collectionInfo[category].isSubsetCollection; + auto schemaVersionField = m_metadata->MakeField>({"schemaVersion_" + category}); + *schemaVersionField = m_collectionInfo[category].schemaVersion; + } + + m_metadata->Freeze(); + m_metadataWriter = + ROOT::Experimental::RNTupleWriter::Append(std::move(m_metadata), root_utils::metaTreeName, *m_file, {}); + + m_metadataWriter->Fill(); + + m_file->Write(); + + // All the tuple writers must be deleted before the file so that they flush + // unwritten output + m_writers.clear(); + m_metadataWriter.reset(); + + m_finished = true; +} + +} // namespace podio diff --git a/src/rootUtils.h b/src/rootUtils.h index 507d24b15..523d5c228 100644 --- a/src/rootUtils.h +++ b/src/rootUtils.h @@ -32,6 +32,75 @@ constexpr static auto metaTreeName = "podio_metadata"; */ constexpr static auto paramBranchName = "PARAMETERS"; +/** + * Names of the fields with the keys and values of the generic parameters for + * the RNTuples until map types are supported + */ +constexpr static auto intKeyName = "GPIntKeys"; +constexpr static auto floatKeyName = "GPFloatKeys"; +constexpr static auto doubleKeyName = "GPDoubleKeys"; +constexpr static auto stringKeyName = "GPStringKeys"; + +constexpr static auto intValueName = "GPIntValues"; +constexpr static auto floatValueName = "GPFloatValues"; +constexpr static auto doubleValueName = "GPDoubleValues"; +constexpr static auto stringValueName = "GPStringValues"; + +/** + * Get the name of the key depending on the type + */ +template +constexpr auto getGPKeyName() { + if constexpr (std::is_same::value) { + return intKeyName; + } else if constexpr (std::is_same::value) { + return floatKeyName; + } else if constexpr (std::is_same::value) { + return doubleKeyName; + } else if constexpr (std::is_same::value) { + return stringKeyName; + } else { + static_assert(sizeof(T) == 0, "Unsupported type for generic parameters"); + } +} + +/** + * Get the name of the value depending on the type + */ +template +constexpr auto getGPValueName() { + if constexpr (std::is_same::value) { + return intValueName; + } else if constexpr (std::is_same::value) { + return floatValueName; + } else if constexpr (std::is_same::value) { + return doubleValueName; + } else if constexpr (std::is_same::value) { + return stringValueName; + } else { + static_assert(sizeof(T) == 0, "Unsupported type for generic parameters"); + } +} + +/** + * Name of the field with the list of categories for RNTuples + */ +constexpr static auto availableCategories = "availableCategories"; + +/** + * Name of the field with the names of the collections for RNTuples + */ +inline std::string collectionName(const std::string& category) { + return category + "_collectionNames"; +} + +/** + * Name of the field with the flag for subset collections for RNTuples + */ +inline std::string subsetCollection(const std::string& category) { + return category + "_isSubsetCollections"; +} + /** * The name of the branch into which we store the build version of podio at the * time of writing the file diff --git a/tests/root_io/CMakeLists.txt b/tests/root_io/CMakeLists.txt index b3bd1b575..ed4b6191f 100644 --- a/tests/root_io/CMakeLists.txt +++ b/tests/root_io/CMakeLists.txt @@ -12,6 +12,13 @@ set(root_dependent_tests read_frame_legacy_root.cpp read_frame_root_multiple.cpp ) +if(ENABLE_RNTUPLE) + set(root_dependent_tests + ${root_dependent_tests} + write_rntuple.cpp + read_rntuple.cpp + ) +endif() set(root_libs podio::podioRootIO) foreach( sourcefile ${root_dependent_tests} ) CREATE_PODIO_TEST(${sourcefile} "${root_libs}") @@ -27,6 +34,9 @@ set_property(TEST read_frame_legacy_root PROPERTY DEPENDS write) set_property(TEST read_timed PROPERTY DEPENDS write_timed) set_property(TEST read_frame_root PROPERTY DEPENDS write_frame_root) set_property(TEST read_frame_root_multiple PROPERTY DEPENDS write_frame_root) +if(ENABLE_RNTUPLE) + set_property(TEST read_rntuple PROPERTY DEPENDS write_rntuple) +endif() add_test(NAME check_benchmark_outputs COMMAND check_benchmark_outputs write_benchmark_root.root read_benchmark_root.root) set_property(TEST check_benchmark_outputs PROPERTY DEPENDS read_timed write_timed) diff --git a/tests/root_io/read_rntuple.cpp b/tests/root_io/read_rntuple.cpp new file mode 100644 index 000000000..59688b2f2 --- /dev/null +++ b/tests/root_io/read_rntuple.cpp @@ -0,0 +1,6 @@ +#include "podio/ROOTNTupleReader.h" +#include "read_frame.h" + +int main() { + return read_frames("example_rntuple.root"); +} diff --git a/tests/root_io/write_rntuple.cpp b/tests/root_io/write_rntuple.cpp new file mode 100644 index 000000000..ce7810c53 --- /dev/null +++ b/tests/root_io/write_rntuple.cpp @@ -0,0 +1,6 @@ +#include "podio/ROOTNTupleWriter.h" +#include "write_frame.h" + +int main() { + write_frames("example_rntuple.root"); +}