// NOTE(review): The text below is a unified git diff whose line breaks have
// been collapsed. Most template argument lists (after std::vector,
// std::shared_ptr, std::make_shared, static_cast, std::static_pointer_cast and
// the bare #include directives) appear to have been stripped during extraction.
// The patch is kept verbatim; restore the line structure and template
// arguments from the upstream commit before applying or compiling it.
//
// Summary of what the patch adds, as visible in the text:
//
// parquet-schema-test.cc
//   * TestConvertArrowSchema (gtest fixture, empty SetUp)
//       - ConvertSchema(fields): builds arrow_schema_ from `fields` and
//         returns the Status of
//         ToParquetSchema(arrow_schema_.get(), &result_schema_).
//       - CheckFlatSchema(nodes): wraps `nodes` in a REPEATED group named
//         "schema", asserts that its field_count() equals that of
//         result_schema_->schema(), then expects result field i to Equals()
//         expected field i for every i.
//   * ParquetFlatPrimitives: the boolean/int32/int64 fields pass `false` as
//     the third Field argument and are expected to convert to REQUIRED nodes;
//     float/double/string omit that argument and are expected to convert to
//     OPTIONAL nodes (presumably Field defaults to nullable -- confirm
//     against the Field constructor).
//   * ParquetFlatDecimals: placeholder; converts an empty field list and
//     checks it against an empty node list.
//
// schema.cc
//   * New includes ("parquet/exception.h", "arrow/types/string.h") and
//     using-declarations for parquet::ParquetException and
//     parquet::Repetition.
//   * PARQUET_CATCH_NOT_OK(s): evaluates `s` inside a try block; if a
//     ParquetException is caught, the enclosing function returns
//     Status::Invalid(e.what()).
//   * StructToNode(type, name, nullable, out): converts each child of `type`
//     with FieldToNode (each call wrapped in RETURN_NOT_OK) and stores a
//     GroupNode named `name` in *out. Repetition is OPTIONAL when `nullable`
//     is true, otherwise REQUIRED.
//   * FieldToNode(field, out): switches on field->type->type to choose a
//     Parquet physical type and LogicalType, then stores
//     PrimitiveNode::Make(field->name, repetition, type, logical_type, length)
//     in *out. Repetition is OPTIONAL when field->nullable, else REQUIRED.
//       BOOL              -> BOOLEAN
//       UINT8 / INT8      -> INT32 + UINT_8 / INT_8
//       UINT16 / INT16    -> INT32 + UINT_16 / INT_16
//       UINT32            -> INT32 + UINT_32
//       INT32             -> INT32
//       UINT64            -> INT64 + UINT_64
//       INT64             -> INT64
//       FLOAT / DOUBLE    -> FLOAT / DOUBLE
//       CHAR              -> FIXED_LEN_BYTE_ARRAY + UTF8, length = type's size
//       STRING            -> BYTE_ARRAY + UTF8
//       BINARY            -> BYTE_ARRAY
//       DATE              -> INT32 + DATE
//       TIMESTAMP         -> INT64 + TIMESTAMP_MILLIS
//       TIMESTAMP_DOUBLE  -> INT64, no logical type (see the TODO in the code)
//       TIME              -> INT64 + TIME_MILLIS
//       STRUCT            -> delegated to StructToNode
//       anything else     -> Status::NotImplemented("unhandled type")
//     `length` stays -1 for every case except CHAR.
//     NOTE(review): the Parquet logical-type specification describes
//     TIME_MILLIS as an annotation on INT32 and UTF8 as an annotation on
//     BYTE_ARRAY; the TIME and CHAR mappings above pair them with INT64 and
//     FIXED_LEN_BYTE_ARRAY respectively -- confirm this is intended.
//   * ToParquetSchema(arrow_schema, out): converts every schema field with
//     FieldToNode, wraps the nodes in a REPEATED group named "schema",
//     assigns a new SchemaDescriptor to *out and calls Init(schema) on it
//     under PARQUET_CATCH_NOT_OK. *out is assigned before Init runs, so it is
//     already set when Init throws and Status::Invalid is returned.
//
// schema.h
//   * Declares FieldToNode and ToParquetSchema.
diff --git a/cpp/src/arrow/parquet/parquet-schema-test.cc b/cpp/src/arrow/parquet/parquet-schema-test.cc index e2280f41189ef..8de739491b56f 100644 --- a/cpp/src/arrow/parquet/parquet-schema-test.cc +++ b/cpp/src/arrow/parquet/parquet-schema-test.cc @@ -161,6 +161,81 @@ TEST_F(TestConvertParquetSchema, UnsupportedThings) { } } +class TestConvertArrowSchema : public ::testing::Test { + public: + virtual void SetUp() {} + + void CheckFlatSchema(const std::vector& nodes) { + NodePtr schema_node = GroupNode::Make("schema", Repetition::REPEATED, nodes); + const GroupNode* expected_schema_node = + static_cast(schema_node.get()); + const GroupNode* result_schema_node = + static_cast(result_schema_->schema().get()); + + ASSERT_EQ(expected_schema_node->field_count(), result_schema_node->field_count()); + + for (int i = 0; i < expected_schema_node->field_count(); i++) { + auto lhs = result_schema_node->field(i); + auto rhs = expected_schema_node->field(i); + EXPECT_TRUE(lhs->Equals(rhs.get())); + } + } + + Status ConvertSchema(const std::vector>& fields) { + arrow_schema_ = std::make_shared(fields); + return ToParquetSchema(arrow_schema_.get(), &result_schema_); + } + + protected: + std::shared_ptr arrow_schema_; + std::shared_ptr<::parquet::SchemaDescriptor> result_schema_; +}; + +TEST_F(TestConvertArrowSchema, ParquetFlatPrimitives) { + std::vector parquet_fields; + std::vector> arrow_fields; + + parquet_fields.push_back( + PrimitiveNode::Make("boolean", Repetition::REQUIRED, ParquetType::BOOLEAN)); + arrow_fields.push_back(std::make_shared("boolean", BOOL, false)); + + parquet_fields.push_back( + PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32)); + arrow_fields.push_back(std::make_shared("int32", INT32, false)); + + parquet_fields.push_back( + PrimitiveNode::Make("int64", Repetition::REQUIRED, ParquetType::INT64)); + arrow_fields.push_back(std::make_shared("int64", INT64, false)); + + parquet_fields.push_back( + PrimitiveNode::Make("float", 
Repetition::OPTIONAL, ParquetType::FLOAT)); + arrow_fields.push_back(std::make_shared("float", FLOAT)); + + parquet_fields.push_back( + PrimitiveNode::Make("double", Repetition::OPTIONAL, ParquetType::DOUBLE)); + arrow_fields.push_back(std::make_shared("double", DOUBLE)); + + // TODO: String types need to be clarified a bit more in the Arrow spec + parquet_fields.push_back(PrimitiveNode::Make( + "string", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, LogicalType::UTF8)); + arrow_fields.push_back(std::make_shared("string", UTF8)); + + ASSERT_OK(ConvertSchema(arrow_fields)); + + CheckFlatSchema(parquet_fields); +} + +TEST_F(TestConvertArrowSchema, ParquetFlatDecimals) { + std::vector parquet_fields; + std::vector> arrow_fields; + + // TODO: Test Decimal Arrow -> Parquet conversion + + ASSERT_OK(ConvertSchema(arrow_fields)); + + CheckFlatSchema(parquet_fields); +} + TEST(TestNodeConversion, DateAndTime) {} } // namespace parquet diff --git a/cpp/src/arrow/parquet/schema.cc b/cpp/src/arrow/parquet/schema.cc index 560e28374066b..214c764f08b6e 100644 --- a/cpp/src/arrow/parquet/schema.cc +++ b/cpp/src/arrow/parquet/schema.cc @@ -17,13 +17,18 @@ #include "arrow/parquet/schema.h" +#include #include #include "parquet/api/schema.h" +#include "parquet/exception.h" #include "arrow/types/decimal.h" +#include "arrow/types/string.h" #include "arrow/util/status.h" +using parquet::ParquetException; +using parquet::Repetition; using parquet::schema::Node; using parquet::schema::NodePtr; using parquet::schema::GroupNode; @@ -36,6 +41,11 @@ namespace arrow { namespace parquet { +#define PARQUET_CATCH_NOT_OK(s) \ + try { \ + (s); \ + } catch (const ParquetException& e) { return Status::Invalid(e.what()); } + const auto BOOL = std::make_shared(); const auto UINT8 = std::make_shared(); const auto INT32 = std::make_shared(); @@ -182,6 +192,126 @@ Status FromParquetSchema( return Status::OK(); } +Status StructToNode(const std::shared_ptr& type, const std::string& name, + bool nullable, 
NodePtr* out) { + Repetition::type repetition = Repetition::REQUIRED; + if (nullable) { repetition = Repetition::OPTIONAL; } + + std::vector children(type->num_children()); + for (int i = 0; i < type->num_children(); i++) { + RETURN_NOT_OK(FieldToNode(type->child(i), &children[i])); + } + + *out = GroupNode::Make(name, repetition, children); + return Status::OK(); +} + +Status FieldToNode(const std::shared_ptr& field, NodePtr* out) { + LogicalType::type logical_type = LogicalType::NONE; + ParquetType::type type; + Repetition::type repetition = Repetition::REQUIRED; + if (field->nullable) { repetition = Repetition::OPTIONAL; } + int length = -1; + + switch (field->type->type) { + // TODO: + // case Type::NA: + // break; + case Type::BOOL: + type = ParquetType::BOOLEAN; + break; + case Type::UINT8: + type = ParquetType::INT32; + logical_type = LogicalType::UINT_8; + break; + case Type::INT8: + type = ParquetType::INT32; + logical_type = LogicalType::INT_8; + break; + case Type::UINT16: + type = ParquetType::INT32; + logical_type = LogicalType::UINT_16; + break; + case Type::INT16: + type = ParquetType::INT32; + logical_type = LogicalType::INT_16; + break; + case Type::UINT32: + type = ParquetType::INT32; + logical_type = LogicalType::UINT_32; + break; + case Type::INT32: + type = ParquetType::INT32; + break; + case Type::UINT64: + type = ParquetType::INT64; + logical_type = LogicalType::UINT_64; + break; + case Type::INT64: + type = ParquetType::INT64; + break; + case Type::FLOAT: + type = ParquetType::FLOAT; + break; + case Type::DOUBLE: + type = ParquetType::DOUBLE; + break; + case Type::CHAR: + type = ParquetType::FIXED_LEN_BYTE_ARRAY; + logical_type = LogicalType::UTF8; + length = static_cast(field->type.get())->size; + break; + case Type::STRING: + type = ParquetType::BYTE_ARRAY; + logical_type = LogicalType::UTF8; + break; + case Type::BINARY: + type = ParquetType::BYTE_ARRAY; + break; + case Type::DATE: + type = ParquetType::INT32; + logical_type = 
LogicalType::DATE; + break; + case Type::TIMESTAMP: + type = ParquetType::INT64; + logical_type = LogicalType::TIMESTAMP_MILLIS; + break; + case Type::TIMESTAMP_DOUBLE: + type = ParquetType::INT64; + // This is specified as seconds since the UNIX epoch + // TODO: Converted type in Parquet? + // logical_type = LogicalType::TIMESTAMP_MILLIS; + break; + case Type::TIME: + type = ParquetType::INT64; + logical_type = LogicalType::TIME_MILLIS; + break; + case Type::STRUCT: { + auto struct_type = std::static_pointer_cast(field->type); + return StructToNode(struct_type, field->name, field->nullable, out); + } break; + default: + // TODO: LIST, DENSE_UNION, SPARSE_UNION, JSON_SCALAR, DECIMAL, DECIMAL_TEXT, VARCHAR + return Status::NotImplemented("unhandled type"); + } + *out = PrimitiveNode::Make(field->name, repetition, type, logical_type, length); + return Status::OK(); +} + +Status ToParquetSchema( + const Schema* arrow_schema, std::shared_ptr<::parquet::SchemaDescriptor>* out) { + std::vector nodes(arrow_schema->num_fields()); + for (int i = 0; i < arrow_schema->num_fields(); i++) { + RETURN_NOT_OK(FieldToNode(arrow_schema->field(i), &nodes[i])); + } + + NodePtr schema = GroupNode::Make("schema", Repetition::REPEATED, nodes); + *out = std::make_shared<::parquet::SchemaDescriptor>(); + PARQUET_CATCH_NOT_OK((*out)->Init(schema)); + + return Status::OK(); +} + } // namespace parquet } // namespace arrow diff --git a/cpp/src/arrow/parquet/schema.h b/cpp/src/arrow/parquet/schema.h index a44a9a4b6a892..bfc7d21138154 100644 --- a/cpp/src/arrow/parquet/schema.h +++ b/cpp/src/arrow/parquet/schema.h @@ -36,6 +36,11 @@ Status NodeToField(const ::parquet::schema::NodePtr& node, std::shared_ptr* out); +Status FieldToNode(const std::shared_ptr& field, ::parquet::schema::NodePtr* out); + +Status ToParquetSchema( + const Schema* arrow_schema, std::shared_ptr<::parquet::SchemaDescriptor>* out);
+ } // namespace parquet } // namespace arrow