diff --git a/vowpalwabbit/core/include/vw/core/error_data.h b/vowpalwabbit/core/include/vw/core/error_data.h index aea7e246b32..18ec7fb92bc 100644 --- a/vowpalwabbit/core/include/vw/core/error_data.h +++ b/vowpalwabbit/core/include/vw/core/error_data.h @@ -25,6 +25,9 @@ ERROR_CODE_DEFINITION( ERROR_CODE_DEFINITION( 13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ") ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ") +ERROR_CODE_DEFINITION(15, fb_parser_span_misaligned, "Input Flatbuffer span is not aligned to an 8-byte boundary. ") +ERROR_CODE_DEFINITION( + 16, fb_parser_span_length_mismatch, "Input Flatbuffer span does not match flatbuffer size prefix. ") // TODO: This is temporary until we switch to the new error handling mechanism. ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ") diff --git a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h index fa181c1ea46..7c5be6cd480 100644 --- a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h +++ b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h @@ -4,7 +4,6 @@ #pragma once -#include "vw/core/api_status.h" #include "vw/core/example.h" #include "vw/core/multi_ex.h" #include "vw/core/shared_data.h" @@ -14,15 +13,21 @@ namespace VW { +namespace experimental +{ class api_status; +} + +using example_sink_f = std::function; namespace parsers { namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples); -bool read_span_flatbuffer( - VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples); + +int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, + VW::multi_ex& examples, example_sink_f example_sink = nullptr, VW::experimental::api_status* status = nullptr); class parser { @@ -57,6 +62,19 @@ class parser VW::experimental::api_status* status = nullptr); int get_namespace_index(const Namespace* ns, namespace_index& ni, VW::experimental::api_status* status = nullptr); + inline void reset_active_multi_ex() + { + _multi_ex_index = 0; + _active_multi_ex = false; + _multi_example_object = nullptr; + } + + inline void reset_active_collection() + { + _example_index = 0; + _active_collection = false; + } + void parse_simple_label(shared_data* sd, polylabel* l, reduction_features* red_features, const SimpleLabel* label); void parse_cb_label(polylabel* l, const CBLabel* label); void parse_ccb_label(polylabel* l, const CCBLabel* label); diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index f70e61f6a93..966377505f9 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -5,12 +5,14 @@ #include "vw/fb_parser/parse_example_flatbuffer.h" #include "vw/core/action_score.h" +#include "vw/core/api_status.h" #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" #include "vw/core/error_constants.h" #include "vw/core/global_data.h" #include "vw/core/parser.h" +#include "vw/core/scope_exit.h" #include "vw/core/vw.h" #include @@ -43,8 +45,8 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl return static_cast(status.get_error_code() == VW::experimental::error_code::success); } -bool read_span_flatbuffer( - VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples) +int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, + VW::multi_ex& examples, example_sink_f example_sink, VW::experimental::api_status* status) { // we expect context to contain a size_prefixed flatbuffer (technically a binary string) // which means: @@ -59,7 +61,6 @@ bool read_span_flatbuffer( // thus context.size() = sizeof(length) + length io_buf unused; - // TODO: How do we report errors out of here? (This is a general API problem with the parsers) size_t address = reinterpret_cast(span); if (address % 8 != 0) { @@ -67,8 +68,8 @@ bool read_span_flatbuffer( sstream << "fb_parser error: flatbuffer data not aligned to 8 bytes" << std::endl; sstream << " span => @" << std::hex << address << std::dec << " % " << 8 << " = " << address % 8 << " (vs desired = " << 0 << ")"; - THROW(sstream.str()); - return false; + + RETURN_ERROR_LS(status, fb_parser_span_misaligned) << sstream.str(); } flatbuffers::uoffset_t flatbuffer_object_size = @@ -79,42 +80,80 @@ bool read_span_flatbuffer( sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl; sstream << " span => @" << std::hex << address << std::dec << " size_prefix = " << flatbuffer_object_size << " length = " << length; - THROW(sstream.str()); - return false; + + RETURN_ERROR_LS(status, fb_parser_span_length_mismatch) << sstream.str(); } VW::multi_ex temp_ex; - temp_ex.push_back(&example_factory()); + + // Use scope_exit because the parser reports errors by throwing exceptions (the code path in the vw driver + // uses the return value to signal completion, not errors). + auto scope_guard = VW::scope_exit( + [&temp_ex, &all, &example_sink]() + { + if (example_sink == nullptr) { VW::finish_example(*all, temp_ex); } + else { example_sink(std::move(temp_ex)); } + }); + + // There is a bit of unhappiness with the interface of the read_XYZ_() functions, because they often + // expect the input multi_ex to have a single "empty" example there. This contributes, in part, to the large + // proliferation of entry points into the JSON parser(s). We want to avoid exposing that insofar as possible, + // so we will check whether we already received a perfectly good example and use that, or create a new one if + // needed. + if (examples.size() > 0) + { + assert(examples.size() == 1); + temp_ex.push_back(examples[0]); + examples.pop_back(); + } + else { temp_ex.push_back(&example_factory()); } bool has_more = true; - VW::experimental::api_status status; do { - switch (all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, &status)) + switch (int result = all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, status)) { case VW::experimental::error_code::success: has_more = true; break; + // Because nothing_to_parse is not an error we have to filter it out here, otherwise + // we could simply do RETURN_IF_FAIL(result) and let the macro handle it. case VW::experimental::error_code::nothing_to_parse: has_more = false; break; default: - std::stringstream sstream; - sstream << "Error parsing examples: " << std::endl; - THROW(sstream.str()); - return false; + RETURN_IF_FAIL(result); } + // The underlying parser will emit a newline example when terminating the parsing + // of a multi_ex block. Since we are collecting it into a multi_ex, we want to + // swallow it here, but should the parser not have followed its contract w.r.t. + // the return value, we should use the presence of the newline example to override + // has_more. has_more &= !temp_ex[0]->is_newline; + // If this is a real example, we need to move it to the output multi_ex; we also + // need to create a new example to replace it for the next run through the parser. if (!temp_ex[0]->is_newline) { - examples.push_back(&example_factory()); - std::swap(examples[examples.size() - 1], temp_ex[0]); + // We avoid doing moves or copy construction here because multi_ex contains + // example pointers. The compile-time code here is meant to call attention + // to here if the underlying type changes. + using temp_ex_element_t = std::remove_reference::type; + using examples_element_t = std::remove_reference::type; + + static_assert(std::is_same::value && + std::is_same::value, + "temp_ex and example must be vector-like over VW::example*"); + + examples.push_back(temp_ex[0]); + + // Since we are using a vector of pointers, we can simply reassign the slot to + // the pointer of the newly created destination example for the parser. + temp_ex[0] = &example_factory(); } } while (has_more); - VW::finish_example(*all, temp_ex); - return true; + return VW::experimental::error_code::success; } const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; } @@ -198,16 +237,17 @@ int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, { _active_multi_ex = true; _multi_example_object = _data->example_obj_as_ExampleCollection()->multi_examples()->Get(_example_index); + + // read from active multi_ex RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); - // read from active collection + // if we are done with the multi example, move to the next one, or finish the collection if (!_active_multi_ex) { _example_index++; if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) { - _example_index = 0; - _active_collection = false; + reset_active_collection(); } } } @@ -216,11 +256,7 @@ int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, const auto ex = _data->example_obj_as_ExampleCollection()->examples()->Get(_example_index); RETURN_IF_FAIL(parse_example(all, examples[0], ex, status)); _example_index++; - if (_example_index == _data->example_obj_as_ExampleCollection()->examples()->size()) - { - _example_index = 0; - _active_collection = false; - } + if (_example_index == _data->example_obj_as_ExampleCollection()->examples()->size()) { reset_active_collection(); } } return VW::experimental::error_code::success; } @@ -231,6 +267,20 @@ int parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl #define RETURN_SUCCESS_FINISHED() \ return buffer_pointer ? VW::experimental::error_code::nothing_to_parse : VW::experimental::error_code::success; + // If we are re-using a single parser instance across multiple invocations, we need to reset + // the state when we get a new buffer_pointer. Otherwise we may be in the middle of a multi_ex + // or example_collection, and the following parse will attempt to reuse the object references + // from the previous buffer, which may have been deallocated. + // TODO: Rewrite the parser to avoid this convoluted, re-entrant logic. + if (buffer_pointer && _flatbuffer_pointer != buffer_pointer) + { + reset_active_multi_ex(); + reset_active_collection(); + } + + // The ExampleCollection processing code owns dispatching to parse_multi_example to handle + // iteration through the outer collection correctly, thus it must have the first chance to + // incoming parse request. if (_active_collection) { RETURN_IF_FAIL(process_collection_item(all, examples, status)); @@ -307,9 +357,7 @@ int parser::parse_multi_example( { // done with multi example, send a newline example and reset ae->is_newline = true; - _multi_ex_index = 0; - _active_multi_ex = false; - _multi_example_object = nullptr; + reset_active_multi_ex(); return VW::experimental::error_code::success; } @@ -325,30 +373,11 @@ int parser::get_namespace_index(const Namespace* ns, namespace_index& ni, VW::ex ni = static_cast(ns->name()->c_str()[0]); return VW::experimental::error_code::success; } - else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_HASH)) + else { ni = ns->hash(); return VW::experimental::error_code::success; } - - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index in collection item with example " - "index " - << _example_index << "and multi example index " << _multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index in multi example index " - << _multi_ex_index; - } - else - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index"; - } } bool get_namespace_hash(VW::workspace* all, const Namespace* ns, uint64_t& hash) @@ -462,7 +491,7 @@ int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* n } else { - if (!has_hashes) { RETURN_NS_PARSER_ERROR(status, fb_parser_name_hash_missing) } + if (!has_hashes) { RETURN_NS_PARSER_ERROR(status, fb_parser_feature_hashes_names_missing) } if (ns->feature_hashes()->size() != ns->feature_values()->size()) { @@ -541,6 +570,7 @@ int parser::parse_flat_label( break; } case Label_NONE: + case Label_no_label: break; default: if (_active_collection && _active_multi_ex) diff --git a/vowpalwabbit/fb_parser/src/parse_label.cc b/vowpalwabbit/fb_parser/src/parse_label.cc index c236747569f..663d54241f1 100644 --- a/vowpalwabbit/fb_parser/src/parse_label.cc +++ b/vowpalwabbit/fb_parser/src/parse_label.cc @@ -3,6 +3,7 @@ // license as described in the file LICENSE. #include "vw/core/action_score.h" +#include "vw/core/api_status.h" #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.h b/vowpalwabbit/fb_parser/tests/example_data_generator.h index b474d3b0c44..6b12f9636fe 100644 --- a/vowpalwabbit/fb_parser/tests/example_data_generator.h +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.h @@ -9,13 +9,19 @@ #include "prototype_example_root.h" #include "prototype_label.h" #include "prototype_namespace.h" +#include "vw/common/future_compat.h" #include "vw/common/hash.h" #include "vw/common/random.h" +#include "vw/core/error_constants.h" +#include "vw/fb_parser/generated/example_generated.h" #include USE_PROTOTYPE_MNEMONICS_EX +using namespace flatbuffers; +namespace fb = VW::parsers::flatbuffer; + namespace vwtest { @@ -26,6 +32,10 @@ class example_data_generator static VW::rand_state create_test_random_state(); + inline bool random_bool() { return rng.get_and_update_random() >= 0.5; } + + inline int random_int(int min, int max) { return static_cast(rng.get_and_update_random() * (max - min) + min); } + prototype_namespace_t create_namespace(std::string name, uint8_t numeric_features, uint8_t string_features); prototype_example_t create_simple_example(uint8_t numeric_features, uint8_t string_features); @@ -40,8 +50,86 @@ class example_data_generator prototype_example_collection_t create_simple_log( uint8_t num_examples, uint8_t numeric_features, uint8_t string_features); +public: + enum NamespaceErrors + { + BAD_NAMESPACE_NO_ERROR = 0, + BAD_NAMESPACE_NAME_HASH_MISSING = 1, // not actually possible, due to how fb works + BAD_NAMESPACE_FEATURE_VALUES_MISSING = 2, + BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH = 4, + BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH = 8, + BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING = 16, + }; + + template + Offset create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w); + private: VW::rand_state rng; }; +template +Offset example_data_generator::create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w) +{ + prototype_namespace_t ns = create_namespace("BadNamespace", 1, 1); + if VW_STD17_CONSTEXPR (errors == NamespaceErrors::BAD_NAMESPACE_NO_ERROR) return ns.create_flatbuffer(builder, w); + + constexpr bool include_ns_name_hash = !(errors & NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING); + constexpr bool include_feature_values = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING); + + constexpr bool include_feature_hashes = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING || + // If we want to check for name/value mismatch, then we need to avoid + // including the feature hashes, as they will be used as a backup + errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH); + constexpr bool skip_a_feature_hash = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH); + static_assert(!skip_a_feature_hash || include_feature_hashes, "Cannot skip a feature hash if they are not included"); + + constexpr bool include_feature_names = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING); + constexpr bool skip_a_feature_name = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH); + static_assert(!skip_a_feature_name || include_feature_names, "Cannot skip a feature name if they are not included"); + + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + + for (size_t i = 0; i < ns.features.size(); i++) + { + const auto& f = ns.features[i]; + + if (include_feature_names && (!skip_a_feature_name || i == 1)) + { + feature_names.push_back(builder.CreateString(f.name)); + } + + if VW_STD17_CONSTEXPR (include_feature_values) feature_values.push_back(f.value); + + if (include_feature_hashes && (!skip_a_feature_hash || i == 0)) { feature_hashes.push_back(f.hash); } + } + + Offset name_offset = Offset(); + if (include_ns_name_hash) { name_offset = builder.CreateString(ns.name); } + + // This function attempts to, insofar as possible, generate a layout that looks like it could have + // been created using the normal serialization code: In this case, that means that the strings for + // the feature names are serialized into the builder before a call to CreateNamespaceDirect is made, + // which is where the feature_names vector is allocated. + Offset>> feature_names_offset = + include_feature_names ? builder.CreateVector(feature_names) : Offset>>(); + Offset> feature_values_offset = + include_feature_values ? builder.CreateVector(feature_values) : Offset>(); + Offset> feature_hashes_offset = + include_feature_hashes ? builder.CreateVector(feature_hashes) : Offset>(); + + fb::NamespaceBuilder ns_builder(builder); + + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_full_hash(VW::hash_space(w, ns.name)); + if VW_STD17_CONSTEXPR (include_feature_hashes) ns_builder.add_feature_hashes(feature_hashes_offset); + if VW_STD17_CONSTEXPR (include_feature_values) ns_builder.add_feature_values(feature_values_offset); + if VW_STD17_CONSTEXPR (include_feature_names) ns_builder.add_feature_names(feature_names_offset); + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_name(name_offset); + + ns_builder.add_hash(ns.feature_group); + return ns_builder.Finish(); +} + } // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc index 35547b0f43e..4170330d5fd 100644 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc @@ -9,6 +9,7 @@ #include "prototype_namespace.h" #include "vw/common/future_compat.h" #include "vw/common/string_view.h" +#include "vw/core/api_status.h" #include "vw/core/constant.h" #include "vw/core/error_constants.h" #include "vw/core/example.h" @@ -253,7 +254,7 @@ TEST(FlatbufferParser, SingleExample_MissingFeatureIndices) examples.push_back(&VW::get_unused_example(all.get())); VW::io_buf unused_buffer; EXPECT_EQ(all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf), - VW::experimental::error_code::fb_parser_name_hash_missing); + VW::experimental::error_code::fb_parser_feature_hashes_names_missing); EXPECT_EQ(all->parser_runtime.example_parser->reader(all.get(), unused_buffer, examples), 0); auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); diff --git a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h index 455eae702ff..1a5ca09e21b 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h +++ b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h @@ -20,18 +20,24 @@ template <> struct fb_type { using type = VW::parsers::flatbuffer::Example; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_Example; }; template <> struct fb_type { using type = VW::parsers::flatbuffer::MultiExample; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_MultiExample; }; template <> struct fb_type { using type = VW::parsers::flatbuffer::ExampleCollection; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_ExampleCollection; }; using union_t = void; diff --git a/vowpalwabbit/fb_parser/tests/read_span_tests.cc b/vowpalwabbit/fb_parser/tests/read_span_tests.cc index acbee2f529d..66a68960a27 100644 --- a/vowpalwabbit/fb_parser/tests/read_span_tests.cc +++ b/vowpalwabbit/fb_parser/tests/read_span_tests.cc @@ -9,6 +9,7 @@ #include "vw/common/string_view.h" #include "vw/core/constant.h" #include "vw/core/error_constants.h" +#include "vw/core/scope_exit.h" #include "vw/core/vw.h" #include "vw/fb_parser/parse_example_flatbuffer.h" #include "vw/test_common/test_common.h" @@ -66,7 +67,7 @@ inline void verify_multi_ex( } // namespace vwtest template ::type> -void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) +void create_flatbuffer_span_and_validate(VW::workspace& w, vwtest::example_data_generator& data_gen, const T& prototype) { // This is what we expect to see when we use read_span_flatbuffer, since this is intended // to be used for inference, and we would prefer not to force consumers of the API to have @@ -84,6 +85,8 @@ void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) flatbuffers::uoffset_t size = builder.GetSize(); VW::multi_ex parsed_examples; + if (data_gen.random_bool()) { parsed_examples.push_back(&ex_fac()); } + VW::parsers::flatbuffer::read_span_flatbuffer(&w, buffer, size, ex_fac, parsed_examples); verify_multi_ex(w, prototype, parsed_examples); @@ -99,7 +102,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_SingleExample) vwtest::prototype_example_t prototype = { {data_gen.create_namespace("A", 3, 4), data_gen.create_namespace("B", 2, 5)}, vwtest::simple_label(1.0f)}; - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) @@ -109,7 +112,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) vwtest::example_data_generator data_gen; vwtest::prototype_multiexample_t prototype = data_gen.create_cb_adf_example(3, 1, "tag"); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) @@ -119,7 +122,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) vwtest::example_data_generator data_gen; vwtest::prototype_example_collection_t prototype = data_gen.create_simple_log(3, 3, 4); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) @@ -129,5 +132,157 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) vwtest::example_data_generator data_gen; vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(1, 3, 4); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); +} + +template +void finish_flatbuffer_and_expect_error(FlatBufferBuilder& builder, Offset root, VW::workspace& w) +{ + VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + VW::example_sink_f ex_sink = [&w](VW::multi_ex&& ex) { VW::finish_example(w, ex); }; + if (vwtest::example_data_generator{}.random_bool()) + { + // This is only valid because ex_fac is grabbing an example from the VW example pool + ex_sink = nullptr; + } + + builder.FinishSizePrefixed(root); + + const uint8_t* buffer = builder.GetBufferPointer(); + flatbuffers::uoffset_t size = builder.GetSize(); + + std::vector buffer_copy(buffer, buffer + size); + + VW::multi_ex parsed_examples; + EXPECT_EQ(VW::parsers::flatbuffer::read_span_flatbuffer( + &w, buffer_copy.data(), buffer_copy.size(), ex_fac, parsed_examples, ex_sink), + error_code); +} + +using namespace_factory_f = std::function(FlatBufferBuilder&, VW::workspace&)>; + +Offset create_bad_ns_root_example(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> namespaces = {ns_fac(builder, w)}; + + Offset label_offset = fb::Createno_label(builder).Union(); + return fb::CreateExample(builder, builder.CreateVector(namespaces), fb::Label_no_label, label_offset); +} + +Offset create_bad_ns_root_multiex( + FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> examples = {create_bad_ns_root_example(builder, w, ns_fac)}; + + return fb::CreateMultiExample(builder, builder.CreateVector(examples)); +} + +template ::type> +using builder_f = Offset (*)(FlatBufferBuilder&, VW::workspace&, namespace_factory_f); + +template +Offset create_bad_ns_root_collection( + FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + if VW_STD17_CONSTEXPR (multiline) + { + // using "auto" here breaks the code coverage build due to template substitution failure + std::vector> inner_examples = {create_bad_ns_root_multiex(builder, w, ns_fac)}; + return fb::CreateExampleCollection(builder, builder.CreateVector(std::vector>()), + builder.CreateVector(inner_examples), multiline); + } + else + { + // using "auto" here breaks the code coverage build due to template substitution failure + std::vector> inner_examples = {create_bad_ns_root_example(builder, w, ns_fac)}; + return fb::CreateExampleCollection(builder, builder.CreateVector(inner_examples), + builder.CreateVector(std::vector>()), multiline); + } +} + +template +void create_flatbuffer_span_and_expect_error(VW::workspace& w, namespace_factory_f ns_fac, builder_f root_builder) +{ + FlatBufferBuilder builder; + Offset data_obj = root_builder(builder, w, ns_fac).Union(); + + Offset root_obj = fb::CreateExampleRoot(builder, root_type, data_obj); + + finish_flatbuffer_and_expect_error(builder, root_obj, w); +} + +using NamespaceErrors = vwtest::example_data_generator::NamespaceErrors; +template +void run_bad_namespace_test(VW::workspace& w) +{ + vwtest::example_data_generator data_gen; + + static_assert(errors != NamespaceErrors::BAD_NAMESPACE_NO_ERROR, "This test is intended to test bad namespaces"); + namespace_factory_f ns_fac = [&data_gen](FlatBufferBuilder& builder, VW::workspace& w) -> Offset + { return data_gen.create_bad_namespace(builder, w); }; + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_example); + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_multiex); + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_collection); + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_collection); } + +template +void run_bad_namespace_test() +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + run_bad_namespace_test(*all); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesMissing) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING; + constexpr int expected_error_code = err::fb_parser_feature_values_missing; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureHashesNamesMissing) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING; + constexpr int expected_error_code = err::fb_parser_feature_hashes_names_missing; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesHashMismatch) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH; + constexpr int expected_error_code = err::fb_parser_size_mismatch_ft_hashes_ft_values; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesNamesMismatch) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH; + constexpr int expected_error_code = err::fb_parser_size_mismatch_ft_names_ft_values; + + run_bad_namespace_test(); +} + +// This test is disabled because it is not possible to create a flatbuffer with a missing namespace name hash. +// TEST(FlatbufferParser, BadNamespace_NameHashMissing) +// { +// namespace err = VW::experimental::error_code; +// constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING; +// constexpr int expected_error_code = err::success; + +// run_bad_namespace_test(); +// }