Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C++] Enable using struct and array of struct as key #7741

Merged
merged 23 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 115 additions & 22 deletions src/idl_gen_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2245,54 +2245,147 @@ class CppGenerator : public BaseGenerator {
}
}

void GenComparatorForStruct(const StructDef &struct_def, size_t space_size,
const std::string lhs_struct_literal,
const std::string rhs_struct_literal) {
code_.SetValue("LHS_PREFIX", lhs_struct_literal);
code_.SetValue("RHS_PREFIX", rhs_struct_literal);
std::string space(space_size, ' ');
for (const auto &curr_field : struct_def.fields.vec) {
const auto curr_field_name = Name(*curr_field);
code_.SetValue("CURR_FIELD_NAME", curr_field_name);
code_.SetValue("LHS", lhs_struct_literal + "_" + curr_field_name);
code_.SetValue("RHS", rhs_struct_literal + "_" + curr_field_name);
const bool is_scalar = IsScalar(curr_field->value.type.base_type);
const bool is_array = IsArray(curr_field->value.type);
const bool is_struct = IsStruct(curr_field->value.type);

// If encouter a key field, call KeyCompareWithValue to compare this field.
if (curr_field->key) {
code_ +=
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";

code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
continue;
}

code_ +=
space + "const auto {{LHS}} = {{LHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
code_ +=
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
if (is_scalar) {
code_ += space + "if ({{LHS}} != {{RHS}})";
code_ += space +
" return static_cast<int>({{LHS}} > {{RHS}}) - "
"static_cast<int>({{LHS}} < {{RHS}});";
} else if (is_array) {
const auto &elem_type = curr_field->value.type.VectorType();
code_ +=
space +
"for (::flatbuffers::uoffset_t i = 0; i < {{LHS}}->size(); i++) {";
code_ += space + " const auto {{LHS}}_elem = {{LHS}}->Get(i);";
code_ += space + " const auto {{RHS}}_elem = {{RHS}}->Get(i);";
if (IsScalar(elem_type.base_type)) {
code_ += space + " if ({{LHS}}_elem != {{RHS}}_elem)";
code_ += space +
" return static_cast<int>({{LHS}}_elem > {{RHS}}_elem) - "
"static_cast<int>({{LHS}}_elem < {{RHS}}_elem);";
code_ += space + "}";

} else if (IsStruct(elem_type)) {
if (curr_field->key) {
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";
code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
continue;
}
GenComparatorForStruct(
*curr_field->value.type.struct_def, space_size + 2,
code_.GetValue("LHS") + "_elem", code_.GetValue("RHS") + "_elem");

code_ += space + "}";
}

} else if (is_struct) {
GenComparatorForStruct(*curr_field->value.type.struct_def, space_size,
code_.GetValue("LHS"), code_.GetValue("RHS"));
}
}
}

// Generate CompareWithValue method for a key field.
void GenKeyFieldMethods(const FieldDef &field) {
FLATBUFFERS_ASSERT(field.key);
const bool is_string = IsString(field.value.type);
const bool is_array = IsArray(field.value.type);

const bool is_struct = IsStruct(field.value.type);
// Generate KeyCompareLessThan function
code_ +=
" bool KeyCompareLessThan(const {{STRUCT_NAME}} * const o) const {";
if (is_string) {
// use operator< of ::flatbuffers::String
code_ += " return *{{FIELD_NAME}}() < *o->{{FIELD_NAME}}();";
} else if (is_array) {
const auto &elem_type = field.value.type.VectorType();
if (IsScalar(elem_type.base_type)) {
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
}
} else {
} else if (is_array || is_struct) {
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
}else {
code_ += " return {{FIELD_NAME}}() < o->{{FIELD_NAME}}();";
}
code_ += " }";

// Generate KeyCompareWithValue function
if (is_string) {
code_ += " int KeyCompareWithValue(const char *_{{FIELD_NAME}}) const {";
code_ += " return strcmp({{FIELD_NAME}}()->c_str(), _{{FIELD_NAME}});";
} else if (is_array) {
const auto &elem_type = field.value.type.VectorType();
std::string input_type = "::flatbuffers::Array<" +
GenTypeGet(elem_type, "", "", " ", false) +
", " + NumToString(elem_type.fixed_length) + ">";
code_.SetValue("INPUT_TYPE", input_type);
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
") const {";
code_ +=
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ +=
" for (::flatbuffers::uoffset_t i = 0; i < "
"curr_{{FIELD_NAME}}->size(); i++) {";

if (IsScalar(elem_type.base_type)) {
std::string input_type = "::flatbuffers::Array<" +
GenTypeBasic(elem_type, false) + ", " +
NumToString(elem_type.fixed_length) + ">";
code_.SetValue("INPUT_TYPE", input_type);
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
") const {";
code_ +=
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ +=
" for (::flatbuffers::uoffset_t i = 0; i < "
"curr_{{FIELD_NAME}}->size(); i++) {";
code_ += " const auto lhs = curr_{{FIELD_NAME}}->Get(i);";
code_ += " const auto rhs = _{{FIELD_NAME}}->Get(i);";
code_ += " if(lhs != rhs)";
code_ += " if (lhs != rhs)";
code_ +=
" return static_cast<int>(lhs > rhs)"
" - static_cast<int>(lhs < rhs);";
code_ += " }";
code_ += " return 0;";
} else if (IsStruct(elem_type)) {
code_ +=
" const auto &lhs_{{FIELD_NAME}} = "
"*(curr_{{FIELD_NAME}}->Get(i));";
code_ +=
" const auto &rhs_{{FIELD_NAME}} = *(_{{FIELD_NAME}}->Get(i));";
GenComparatorForStruct(*elem_type.struct_def, 6,
"lhs_" + code_.GetValue("FIELD_NAME"),
"rhs_" + code_.GetValue("FIELD_NAME"));
}
code_ += " }";
code_ += " return 0;";
} else if (is_struct) {
const auto *struct_def = field.value.type.struct_def;
code_.SetValue("INPUT_TYPE",
GenTypeGet(field.value.type, "", "", "", false));
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} &_{{FIELD_NAME}}) "
"const {";
code_ += " const auto &lhs_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ += " const auto &rhs_{{FIELD_NAME}} = _{{FIELD_NAME}};";
GenComparatorForStruct(*struct_def, 4,
"lhs_" + code_.GetValue("FIELD_NAME"),
"rhs_" + code_.GetValue("FIELD_NAME"));
code_ += " return 0;";

} else {
FLATBUFFERS_ASSERT(IsScalar(field.value.type.base_type));
auto type = GenTypeBasic(field.value.type, false);
Expand Down
9 changes: 6 additions & 3 deletions src/idl_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ CheckedError Parser::ParseField(StructDef &struct_def) {
ECHECK(ParseType(type));

if (struct_def.fixed) {
if (IsIncompleteStruct(type) ||
if (IsIncompleteStruct(type) ||
(IsArray(type) && IsIncompleteStruct(type.VectorType()))) {
std::string type_name = IsArray(type) ? type.VectorType().struct_def->name : type.struct_def->name;
return Error(std::string("Incomplete type in struct is not allowed, type name: ") + type_name);
Expand Down Expand Up @@ -1072,8 +1072,11 @@ CheckedError Parser::ParseField(StructDef &struct_def) {
if (field->key) {
if (struct_def.has_key) return Error("only one field may be set as 'key'");
struct_def.has_key = true;
auto is_valid = IsScalar(type.base_type) || IsString(type);
if (IsArray(type)) { is_valid |= IsScalar(type.VectorType().base_type); }
auto is_valid = IsScalar(type.base_type) || IsString(type) || IsStruct(type);
if (IsArray(type)) {
is_valid |=
IsScalar(type.VectorType().base_type) || IsStruct(type.VectorType());
}
if (!is_valid) {
return Error(
"'key' field must be string, scalar type or fixed size array of "
Expand Down
28 changes: 28 additions & 0 deletions tests/key_field/key_field_sample.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,40 @@ struct Bar {
b: uint8;
}

struct Color {
rgb: [float:3] (key);
tag: uint8;
}

struct Apple {
tag: uint8;
color: Color(key);
}

struct Fruit {
a: Apple (key);
b: uint8;
}

struct Rice {
origin: [uint8:3];
quantity: uint32;
}

struct Grain {
a: [Rice:3] (key);
tag: uint8;
}

table FooTable {
a: int;
b: int;
c: string (key);
d: [Baz];
e: [Bar];
f: [Apple];
g: [Fruit];
h: [Grain];
}
root_type FooTable;

Loading