Skip to content

Commit

Permalink
[C++] Enable using struct and array of struct as key (#7741)
Browse files Browse the repository at this point in the history
* add unit tests for support struct as key

* make changes to parser and add helper function to generate comparator for struct

* implement

* add more unit tests

* format

* just a test

* test done

* rerun generator

* restore build file

* address comment

* format

* rebase

* rebase

* add more unit tests

* rerun generator

* address some comments

* address comment

* update

* format

* address comment

Co-authored-by: Wen Sun <sunwen@google.com>
Co-authored-by: Derek Bailey <derekbailey@google.com>
  • Loading branch information
3 people authored Jan 25, 2023
1 parent ee848a0 commit 802a3a0
Show file tree
Hide file tree
Showing 7 changed files with 853 additions and 37 deletions.
137 changes: 115 additions & 22 deletions src/idl_gen_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2245,54 +2245,147 @@ class CppGenerator : public BaseGenerator {
}
}

void GenComparatorForStruct(const StructDef &struct_def, size_t space_size,
const std::string lhs_struct_literal,
const std::string rhs_struct_literal) {
code_.SetValue("LHS_PREFIX", lhs_struct_literal);
code_.SetValue("RHS_PREFIX", rhs_struct_literal);
std::string space(space_size, ' ');
for (const auto &curr_field : struct_def.fields.vec) {
const auto curr_field_name = Name(*curr_field);
code_.SetValue("CURR_FIELD_NAME", curr_field_name);
code_.SetValue("LHS", lhs_struct_literal + "_" + curr_field_name);
code_.SetValue("RHS", rhs_struct_literal + "_" + curr_field_name);
const bool is_scalar = IsScalar(curr_field->value.type.base_type);
const bool is_array = IsArray(curr_field->value.type);
const bool is_struct = IsStruct(curr_field->value.type);

// If encouter a key field, call KeyCompareWithValue to compare this field.
if (curr_field->key) {
code_ +=
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";

code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
continue;
}

code_ +=
space + "const auto {{LHS}} = {{LHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
code_ +=
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
if (is_scalar) {
code_ += space + "if ({{LHS}} != {{RHS}})";
code_ += space +
" return static_cast<int>({{LHS}} > {{RHS}}) - "
"static_cast<int>({{LHS}} < {{RHS}});";
} else if (is_array) {
const auto &elem_type = curr_field->value.type.VectorType();
code_ +=
space +
"for (::flatbuffers::uoffset_t i = 0; i < {{LHS}}->size(); i++) {";
code_ += space + " const auto {{LHS}}_elem = {{LHS}}->Get(i);";
code_ += space + " const auto {{RHS}}_elem = {{RHS}}->Get(i);";
if (IsScalar(elem_type.base_type)) {
code_ += space + " if ({{LHS}}_elem != {{RHS}}_elem)";
code_ += space +
" return static_cast<int>({{LHS}}_elem > {{RHS}}_elem) - "
"static_cast<int>({{LHS}}_elem < {{RHS}}_elem);";
code_ += space + "}";

} else if (IsStruct(elem_type)) {
if (curr_field->key) {
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";
code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
continue;
}
GenComparatorForStruct(
*curr_field->value.type.struct_def, space_size + 2,
code_.GetValue("LHS") + "_elem", code_.GetValue("RHS") + "_elem");

code_ += space + "}";
}

} else if (is_struct) {
GenComparatorForStruct(*curr_field->value.type.struct_def, space_size,
code_.GetValue("LHS"), code_.GetValue("RHS"));
}
}
}

// Generate CompareWithValue method for a key field.
void GenKeyFieldMethods(const FieldDef &field) {
FLATBUFFERS_ASSERT(field.key);
const bool is_string = IsString(field.value.type);
const bool is_array = IsArray(field.value.type);

const bool is_struct = IsStruct(field.value.type);
// Generate KeyCompareLessThan function
code_ +=
" bool KeyCompareLessThan(const {{STRUCT_NAME}} * const o) const {";
if (is_string) {
// use operator< of ::flatbuffers::String
code_ += " return *{{FIELD_NAME}}() < *o->{{FIELD_NAME}}();";
} else if (is_array) {
const auto &elem_type = field.value.type.VectorType();
if (IsScalar(elem_type.base_type)) {
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
}
} else {
} else if (is_array || is_struct) {
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
}else {
code_ += " return {{FIELD_NAME}}() < o->{{FIELD_NAME}}();";
}
code_ += " }";

// Generate KeyCompareWithValue function
if (is_string) {
code_ += " int KeyCompareWithValue(const char *_{{FIELD_NAME}}) const {";
code_ += " return strcmp({{FIELD_NAME}}()->c_str(), _{{FIELD_NAME}});";
} else if (is_array) {
const auto &elem_type = field.value.type.VectorType();
std::string input_type = "::flatbuffers::Array<" +
GenTypeGet(elem_type, "", "", " ", false) +
", " + NumToString(elem_type.fixed_length) + ">";
code_.SetValue("INPUT_TYPE", input_type);
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
") const {";
code_ +=
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ +=
" for (::flatbuffers::uoffset_t i = 0; i < "
"curr_{{FIELD_NAME}}->size(); i++) {";

if (IsScalar(elem_type.base_type)) {
std::string input_type = "::flatbuffers::Array<" +
GenTypeBasic(elem_type, false) + ", " +
NumToString(elem_type.fixed_length) + ">";
code_.SetValue("INPUT_TYPE", input_type);
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
") const {";
code_ +=
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ +=
" for (::flatbuffers::uoffset_t i = 0; i < "
"curr_{{FIELD_NAME}}->size(); i++) {";
code_ += " const auto lhs = curr_{{FIELD_NAME}}->Get(i);";
code_ += " const auto rhs = _{{FIELD_NAME}}->Get(i);";
code_ += " if(lhs != rhs)";
code_ += " if (lhs != rhs)";
code_ +=
" return static_cast<int>(lhs > rhs)"
" - static_cast<int>(lhs < rhs);";
code_ += " }";
code_ += " return 0;";
} else if (IsStruct(elem_type)) {
code_ +=
" const auto &lhs_{{FIELD_NAME}} = "
"*(curr_{{FIELD_NAME}}->Get(i));";
code_ +=
" const auto &rhs_{{FIELD_NAME}} = *(_{{FIELD_NAME}}->Get(i));";
GenComparatorForStruct(*elem_type.struct_def, 6,
"lhs_" + code_.GetValue("FIELD_NAME"),
"rhs_" + code_.GetValue("FIELD_NAME"));
}
code_ += " }";
code_ += " return 0;";
} else if (is_struct) {
const auto *struct_def = field.value.type.struct_def;
code_.SetValue("INPUT_TYPE",
GenTypeGet(field.value.type, "", "", "", false));
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} &_{{FIELD_NAME}}) "
"const {";
code_ += " const auto &lhs_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ += " const auto &rhs_{{FIELD_NAME}} = _{{FIELD_NAME}};";
GenComparatorForStruct(*struct_def, 4,
"lhs_" + code_.GetValue("FIELD_NAME"),
"rhs_" + code_.GetValue("FIELD_NAME"));
code_ += " return 0;";

} else {
FLATBUFFERS_ASSERT(IsScalar(field.value.type.base_type));
auto type = GenTypeBasic(field.value.type, false);
Expand Down
9 changes: 6 additions & 3 deletions src/idl_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ CheckedError Parser::ParseField(StructDef &struct_def) {
ECHECK(ParseType(type));

if (struct_def.fixed) {
if (IsIncompleteStruct(type) ||
if (IsIncompleteStruct(type) ||
(IsArray(type) && IsIncompleteStruct(type.VectorType()))) {
std::string type_name = IsArray(type) ? type.VectorType().struct_def->name : type.struct_def->name;
return Error(std::string("Incomplete type in struct is not allowed, type name: ") + type_name);
Expand Down Expand Up @@ -1072,8 +1072,11 @@ CheckedError Parser::ParseField(StructDef &struct_def) {
if (field->key) {
if (struct_def.has_key) return Error("only one field may be set as 'key'");
struct_def.has_key = true;
auto is_valid = IsScalar(type.base_type) || IsString(type);
if (IsArray(type)) { is_valid |= IsScalar(type.VectorType().base_type); }
auto is_valid = IsScalar(type.base_type) || IsString(type) || IsStruct(type);
if (IsArray(type)) {
is_valid |=
IsScalar(type.VectorType().base_type) || IsStruct(type.VectorType());
}
if (!is_valid) {
return Error(
"'key' field must be string, scalar type or fixed size array of "
Expand Down
28 changes: 28 additions & 0 deletions tests/key_field/key_field_sample.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,40 @@ struct Bar {
b: uint8;
}

struct Color {
rgb: [float:3] (key);
tag: uint8;
}

struct Apple {
tag: uint8;
color: Color(key);
}

struct Fruit {
a: Apple (key);
b: uint8;
}

struct Rice {
origin: [uint8:3];
quantity: uint32;
}

struct Grain {
a: [Rice:3] (key);
tag: uint8;
}

table FooTable {
a: int;
b: int;
c: string (key);
d: [Baz];
e: [Bar];
f: [Apple];
g: [Fruit];
h: [Grain];
}
root_type FooTable;

Loading

0 comments on commit 802a3a0

Please sign in to comment.