Skip to content

Commit

Permalink
Refactor LogicalType for Parquet (#14264)
Browse files Browse the repository at this point in the history
A continuation of #14097, this PR refactors the LogicalType struct to use the new way of treating unions defined in the Parquet Thrift specification (more enum-like than struct-like).

Authors:
  - Ed Seidl (https://github.com/etseidl)
  - Vukasin Milovanovic (https://github.com/vuule)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Nghia Truong (https://github.com/ttnghia)

URL: #14264
  • Loading branch information
etseidl authored Oct 20, 2023
1 parent e7c6365 commit 253f6a6
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 303 deletions.
95 changes: 22 additions & 73 deletions cpp/src/io/parquet/compact_protocol_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,61 +339,6 @@ struct parquet_field_struct_list : public parquet_field_list<T> {
}
};

// TODO(ets): rework the current union handling (which mirrors thrift) to use std::optional
// fields in a struct
/**
 * @brief Field handler that reads one struct-typed member of a union from a
 * CompactProtocolReader.
 *
 * This is the primary template, selected when `is_empty` is false: the member is decoded into
 * `val` by calling `CompactProtocolReader::read`.
 *
 * @tparam T Type of the union member
 * @tparam is_empty Whether `T` is an empty type; `true` selects the specialization instead
 *
 * @return `operator()` yields true when the field type is not a struct or when reading the
 * member fails, and false otherwise
 */
template <typename T, bool is_empty = false>
class ParquetFieldUnionFunctor : public parquet_field {
  bool& is_set;  // set to true once a struct-typed field has been seen for this member
  T& val;        // destination the member is read into

 public:
  ParquetFieldUnionFunctor(int field, bool& set_flag, T& value)
    : parquet_field(field), is_set(set_flag), val(value)
  {
  }

  inline bool operator()(CompactProtocolReader* cpr, int field_type)
  {
    // Only struct-typed fields are accepted; anything else is reported as a failure.
    if (field_type != ST_FLD_STRUCT) { return true; }

    // The flag is raised before the read, so it stays set even if the read fails.
    is_set             = true;
    bool const read_ok = cpr->read(&val);
    return !read_ok;
  }
};

/**
 * @brief Specialization of ParquetFieldUnionFunctor selected when `is_empty` is true.
 *
 * Nothing is read into `val`: for a struct-typed field the `is_set` flag is raised and the field
 * is skipped via `CompactProtocolReader::skip_struct_field`.
 *
 * @tparam T Type of the union member
 *
 * @return `operator()` yields true when the field type is not a struct, and false otherwise
 */
template <typename T>
class ParquetFieldUnionFunctor<T, true> : public parquet_field {
  bool& is_set;  // set to true once a struct-typed field has been seen for this member
  T& val;        // bound by the constructor but never written by this specialization

 public:
  ParquetFieldUnionFunctor(int field, bool& set_flag, T& value)
    : parquet_field(field), is_set(set_flag), val(value)
  {
  }

  inline bool operator()(CompactProtocolReader* cpr, int field_type)
  {
    bool const type_mismatch = (field_type != ST_FLD_STRUCT);
    if (!type_mismatch) {
      is_set = true;
      cpr->skip_struct_field(field_type);
    }
    return type_mismatch;
  }
};

/**
 * @brief Builds the ParquetFieldUnionFunctor matching the member type `T`.
 *
 * The second template argument is `std::is_empty_v<T>`, so empty member types get the
 * specialization and all other types get the primary template.
 *
 * @param f Field number forwarded to the functor
 * @param b Flag the functor sets when the member is encountered
 * @param v Member object forwarded to the functor
 * @return The constructed functor
 */
template <typename T>
ParquetFieldUnionFunctor<T, std::is_empty_v<T>> ParquetFieldUnion(int f, bool& b, T& v)
{
  using functor_type = ParquetFieldUnionFunctor<T, std::is_empty_v<T>>;
  return functor_type(f, b, v);
}

/**
* @brief Functor to read a binary from CompactProtocolReader
*
Expand Down Expand Up @@ -595,34 +540,38 @@ bool CompactProtocolReader::read(FileMetaData* f)

/**
 * @brief Reads a SchemaElement by binding its members to per-field handlers (field ids 1-10)
 * and passing the resulting tuple to function_builder.
 *
 * NOTE(review): field ids 6 and 10 each appear twice below, once with a plain handler and once
 * with a parquet_field_optional-based handler. This looks like the removed and added sides of a
 * diff shown together — confirm against the actual file.
 *
 * @param s SchemaElement whose members are bound to the handlers
 * @return The result of function_builder(this, op)
 */
bool CompactProtocolReader::read(SchemaElement* s)
{
// Aliases for the optional-wrapped handlers used for fields 6 and 10.
using optional_converted_type =
parquet_field_optional<ConvertedType, parquet_field_enum<ConvertedType>>;
using optional_logical_type =
parquet_field_optional<LogicalType, parquet_field_struct<LogicalType>>;
auto op = std::make_tuple(parquet_field_enum<Type>(1, s->type),
parquet_field_int32(2, s->type_length),
parquet_field_enum<FieldRepetitionType>(3, s->repetition_type),
parquet_field_string(4, s->name),
parquet_field_int32(5, s->num_children),
// Field 6, plain enum handler.
parquet_field_enum<ConvertedType>(6, s->converted_type),
// Field 6, optional-wrapped handler.
optional_converted_type(6, s->converted_type),
parquet_field_int32(7, s->decimal_scale),
parquet_field_int32(8, s->decimal_precision),
parquet_field_optional<int32_t, parquet_field_int32>(9, s->field_id),
// Field 10, plain struct handler.
parquet_field_struct(10, s->logical_type));
// Field 10, optional-wrapped handler.
optional_logical_type(10, s->logical_type));
return function_builder(this, op);
}

/**
 * @brief Reads a LogicalType union by passing a tuple of per-field handlers to function_builder.
 *
 * NOTE(review): two `auto op` tuples appear below, one built from ParquetFieldUnion and one from
 * parquet_field_union_enumerator / parquet_field_union_struct. This looks like the removed and
 * added sides of a diff shown together — confirm against the actual file.
 *
 * @param l LogicalType to populate
 * @return The result of function_builder(this, op)
 */
bool CompactProtocolReader::read(LogicalType* l)
{
// Variant that records each member in its own `isset` flag and member object.
// Field id 9 does not appear in this list.
auto op =
std::make_tuple(ParquetFieldUnion(1, l->isset.STRING, l->STRING),
ParquetFieldUnion(2, l->isset.MAP, l->MAP),
ParquetFieldUnion(3, l->isset.LIST, l->LIST),
ParquetFieldUnion(4, l->isset.ENUM, l->ENUM),
ParquetFieldUnion(5, l->isset.DECIMAL, l->DECIMAL), // read the struct
ParquetFieldUnion(6, l->isset.DATE, l->DATE),
ParquetFieldUnion(7, l->isset.TIME, l->TIME), // read the struct
ParquetFieldUnion(8, l->isset.TIMESTAMP, l->TIMESTAMP), // read the struct
ParquetFieldUnion(10, l->isset.INTEGER, l->INTEGER), // read the struct
ParquetFieldUnion(11, l->isset.UNKNOWN, l->UNKNOWN),
ParquetFieldUnion(12, l->isset.JSON, l->JSON),
ParquetFieldUnion(13, l->isset.BSON, l->BSON));
// Variant where every handler is bound to the single `l->type` member. Fields 5, 7, 8 and 10
// use parquet_field_union_struct and also bind a payload member (decimal_type, time_type,
// timestamp_type, int_type); the rest use parquet_field_union_enumerator.
// Field id 9 does not appear in this list either.
auto op = std::make_tuple(
parquet_field_union_enumerator(1, l->type),
parquet_field_union_enumerator(2, l->type),
parquet_field_union_enumerator(3, l->type),
parquet_field_union_enumerator(4, l->type),
parquet_field_union_struct<LogicalType::Type, DecimalType>(5, l->type, l->decimal_type),
parquet_field_union_enumerator(6, l->type),
parquet_field_union_struct<LogicalType::Type, TimeType>(7, l->type, l->time_type),
parquet_field_union_struct<LogicalType::Type, TimestampType>(8, l->type, l->timestamp_type),
parquet_field_union_struct<LogicalType::Type, IntType>(10, l->type, l->int_type),
parquet_field_union_enumerator(11, l->type),
parquet_field_union_enumerator(12, l->type),
parquet_field_union_enumerator(13, l->type));
return function_builder(this, op);
}

Expand All @@ -648,9 +597,9 @@ bool CompactProtocolReader::read(TimestampType* t)

/**
 * @brief Reads a TimeUnit union (field ids 1-3) by passing a tuple of per-field handlers to
 * function_builder.
 *
 * NOTE(review): two `auto op` tuples appear below, one built from ParquetFieldUnion and one from
 * parquet_field_union_enumerator. This looks like the removed and added sides of a diff shown
 * together — confirm against the actual file.
 *
 * @param u TimeUnit to populate
 * @return The result of function_builder(this, op)
 */
bool CompactProtocolReader::read(TimeUnit* u)
{
// Variant that records each unit in its own `isset` flag and member object.
auto op = std::make_tuple(ParquetFieldUnion(1, u->isset.MILLIS, u->MILLIS),
ParquetFieldUnion(2, u->isset.MICROS, u->MICROS),
ParquetFieldUnion(3, u->isset.NANOS, u->NANOS));
// Variant where all three handlers are bound to the single `u->type` member.
auto op = std::make_tuple(parquet_field_union_enumerator(1, u->type),
parquet_field_union_enumerator(2, u->type),
parquet_field_union_enumerator(3, u->type));
return function_builder(this, op);
}

Expand Down
81 changes: 37 additions & 44 deletions cpp/src/io/parquet/compact_protocol_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "compact_protocol_writer.hpp"

#include <cudf/utilities/error.hpp>

namespace cudf::io::parquet::detail {

/**
Expand Down Expand Up @@ -46,13 +48,11 @@ size_t CompactProtocolWriter::write(DecimalType const& decimal)
/**
 * @brief Writes a TimeUnit union through a CompactProtocolFieldWriter and returns the writer's
 * value().
 *
 * NOTE(review): an `isset`-based if-chain and a `switch` on `time_unit.type` both appear below
 * and the if-chain is never closed. This looks like the removed and added sides of a diff shown
 * together — confirm against the actual file.
 *
 * @param time_unit The TimeUnit to write
 * @return c.value() from the field writer
 */
size_t CompactProtocolWriter::write(TimeUnit const& time_unit)
{
CompactProtocolFieldWriter c(*this);
// Variant driven by per-member `isset` flags; writes the first set member with literal field
// ids 1-3.
auto const isset = time_unit.isset;
if (isset.MILLIS) {
c.field_struct(1, time_unit.MILLIS);
} else if (isset.MICROS) {
c.field_struct(2, time_unit.MICROS);
} else if (isset.NANOS) {
c.field_struct(3, time_unit.NANOS);
// Variant driven by `time_unit.type`; the enumerator itself is passed to field_empty_struct.
// Presumably the enumerator values equal the field ids 1-3 used above — verify against the
// TimeUnit definition.
switch (time_unit.type) {
case TimeUnit::MILLIS:
case TimeUnit::MICROS:
case TimeUnit::NANOS: c.field_empty_struct(time_unit.type); break;
// Any other value is rejected via CUDF_FAIL.
default: CUDF_FAIL("Trying to write an invalid TimeUnit " + std::to_string(time_unit.type));
}
return c.value();
}
Expand Down Expand Up @@ -84,31 +84,29 @@ size_t CompactProtocolWriter::write(IntType const& integer)
/**
 * @brief Writes a LogicalType union through a CompactProtocolFieldWriter and returns the
 * writer's value().
 *
 * NOTE(review): an `isset`-based if-chain and a `switch` on `logical_type.type` both appear
 * below and the if-chain is never closed. This looks like the removed and added sides of a diff
 * shown together — confirm against the actual file.
 *
 * @param logical_type The LogicalType to write
 * @return c.value() from the field writer
 */
size_t CompactProtocolWriter::write(LogicalType const& logical_type)
{
CompactProtocolFieldWriter c(*this);
// Variant driven by per-member `isset` flags; writes the first set member with literal field
// ids 1-8 and 10-13 (no branch uses id 9).
auto const isset = logical_type.isset;
if (isset.STRING) {
c.field_struct(1, logical_type.STRING);
} else if (isset.MAP) {
c.field_struct(2, logical_type.MAP);
} else if (isset.LIST) {
c.field_struct(3, logical_type.LIST);
} else if (isset.ENUM) {
c.field_struct(4, logical_type.ENUM);
} else if (isset.DECIMAL) {
c.field_struct(5, logical_type.DECIMAL);
} else if (isset.DATE) {
c.field_struct(6, logical_type.DATE);
} else if (isset.TIME) {
c.field_struct(7, logical_type.TIME);
} else if (isset.TIMESTAMP) {
c.field_struct(8, logical_type.TIMESTAMP);
} else if (isset.INTEGER) {
c.field_struct(10, logical_type.INTEGER);
} else if (isset.UNKNOWN) {
c.field_struct(11, logical_type.UNKNOWN);
} else if (isset.JSON) {
c.field_struct(12, logical_type.JSON);
} else if (isset.BSON) {
c.field_struct(13, logical_type.BSON);
// Variant driven by `logical_type.type`; the enumerator itself is passed as the field id.
// Presumably the enumerator values equal the literal ids used above — verify against the
// LogicalType definition.
switch (logical_type.type) {
// These cases write an empty struct carrying only the field id.
case LogicalType::STRING:
case LogicalType::MAP:
case LogicalType::LIST:
case LogicalType::ENUM:
case LogicalType::DATE:
case LogicalType::UNKNOWN:
case LogicalType::JSON:
case LogicalType::BSON: c.field_empty_struct(logical_type.type); break;
// These cases write a payload obtained with `.value()` on the matching member.
// NOTE(review): if these members are std::optional, `.value()` throws when the member is
// unset — confirm callers always populate the member that matches `type`.
case LogicalType::DECIMAL:
c.field_struct(LogicalType::DECIMAL, logical_type.decimal_type.value());
break;
case LogicalType::TIME:
c.field_struct(LogicalType::TIME, logical_type.time_type.value());
break;
case LogicalType::TIMESTAMP:
c.field_struct(LogicalType::TIMESTAMP, logical_type.timestamp_type.value());
break;
case LogicalType::INTEGER:
c.field_struct(LogicalType::INTEGER, logical_type.int_type.value());
break;
// Any other value is rejected via CUDF_FAIL.
default:
CUDF_FAIL("Trying to write an invalid LogicalType " + std::to_string(logical_type.type));
}
return c.value();
}
Expand All @@ -124,20 +122,15 @@ size_t CompactProtocolWriter::write(SchemaElement const& s)
c.field_string(4, s.name);

if (s.type == UNDEFINED_TYPE) { c.field_int(5, s.num_children); }
if (s.converted_type != UNKNOWN) {
c.field_int(6, s.converted_type);
if (s.converted_type.has_value()) {
c.field_int(6, s.converted_type.value());
if (s.converted_type == DECIMAL) {
c.field_int(7, s.decimal_scale);
c.field_int(8, s.decimal_precision);
}
}
if (s.field_id) { c.field_int(9, s.field_id.value()); }
auto const isset = s.logical_type.isset;
// TODO: add handling for all logical types
// if (isset.STRING or isset.MAP or isset.LIST or isset.ENUM or isset.DECIMAL or isset.DATE or
// isset.TIME or isset.TIMESTAMP or isset.INTEGER or isset.UNKNOWN or isset.JSON or isset.BSON)
// {
if (isset.TIMESTAMP or isset.TIME) { c.field_struct(10, s.logical_type); }
if (s.field_id.has_value()) { c.field_int(9, s.field_id.value()); }
if (s.logical_type.has_value()) { c.field_struct(10, s.logical_type.value()); }
return c.value();
}

Expand Down Expand Up @@ -223,9 +216,9 @@ size_t CompactProtocolWriter::write(OffsetIndex const& s)
/**
 * @brief Writes a ColumnOrder through a CompactProtocolFieldWriter and returns the writer's
 * value().
 *
 * NOTE(review): two `switch` heads, two TYPE_ORDER cases and two `default` labels appear below.
 * This looks like the removed and added sides of a diff shown together — confirm against the
 * actual file.
 *
 * @param co The ColumnOrder to write
 * @return c.value() from the field writer
 */
size_t CompactProtocolWriter::write(ColumnOrder const& co)
{
CompactProtocolFieldWriter c(*this);
// Variant switching on `co` itself: writes an empty struct with literal field id 1 and
// silently ignores any other value.
switch (co) {
case ColumnOrder::TYPE_ORDER: c.field_empty_struct(1); break;
default: break;
// Variant switching on `co.type`: passes the enumerator as the field id and rejects any other
// value via CUDF_FAIL.
switch (co.type) {
case ColumnOrder::TYPE_ORDER: c.field_empty_struct(co.type); break;
default: CUDF_FAIL("Trying to write an invalid ColumnOrder " + std::to_string(co.type));
}
return c.value();
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/io/parquet/page_decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,8 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s,
units = cudf::timestamp_ms::period::den;
} else if (s->col.converted_type == TIMESTAMP_MICROS) {
units = cudf::timestamp_us::period::den;
} else if (s->col.logical_type.TIMESTAMP.unit.isset.NANOS) {
} else if (s->col.logical_type.has_value() and
s->col.logical_type->is_timestamp_nanos()) {
units = cudf::timestamp_ns::period::den;
}
if (units and units != s->col.ts_clock_rate) {
Expand Down
Loading

0 comments on commit 253f6a6

Please sign in to comment.