Add support for array types #56

Merged
1 commit merged on Sep 9, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -64,8 +64,9 @@ The Dataset extension performs the following validations:
* The Dataset has no columns with unknown types, unless `excludeColumnsWithUnknownTypes` is set to true

The Dataset extension performs the following transformations:
* Drops all columns of complex datatypes such as `StructType`, `MapType` or `ArrayType` as they
* Drops all columns of complex datatypes such as `StructType` or `MapType` as they
are not supported by `DruidSource`. This is only done if `excludeColumnsWithUnknownTypes` is set to true, otherwise validation has already failed.
* `ArrayType` is supported with `StringType`, `LongType` and `DoubleType` element types
* Converts `Date`/`Timestamp` type columns to `String`, except for the `time_column`
- See [Druid Docs / Data types](https://druid.apache.org/docs/latest/querying/sql.html#standard-types)
* Adds a new column `__PARTITION_TIME__` whose value is based on `time_column` column and the given [segment granularity](#segment-granularity)
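
For reference, a minimal sketch of a Dataset whose array columns satisfy the rules above, using only the plain Spark Java API. The column names and values are invented for illustration; it only builds the Dataset and does not show the `DruidSource` write itself.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;

public class ArrayColumnsExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("array-example").master("local[1]").getOrCreate();

        // Array columns are limited to String, Long and Double element types.
        StructType schema = new StructType()
                .add("date", DataTypes.TimestampType)
                .add("country", DataTypes.StringType)
                .add("tags", DataTypes.createArrayType(DataTypes.StringType))
                .add("scores", DataTypes.createArrayType(DataTypes.DoubleType))
                .add("counts", DataTypes.createArrayType(DataTypes.LongType));

        List<Row> rows = Arrays.asList(
                RowFactory.create(Timestamp.valueOf("2024-09-09 00:00:00"), "FI",
                        new String[]{"a", "b"}, new Double[]{1.5, 2.5}, new Long[]{1L, 2L}));

        Dataset<Row> dataset = spark.createDataFrame(rows, schema);
        dataset.printSchema();
        spark.stop();
    }
}
```
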
7 changes: 7 additions & 0 deletions src/main/java/com/rovio/ingest/DataSegmentCommitMessage.java
@@ -20,6 +20,7 @@
import com.fasterxml.jackson.databind.InjectableValues;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.NamedType;
import org.apache.druid.guice.NestedDataModule;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchModule;
@@ -29,6 +30,8 @@
import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchBuildComplexMetricSerde;
import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchMergeComplexMetricSerde;
import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchModule;
import org.apache.druid.segment.DefaultColumnFormatConfig;
import org.apache.druid.segment.nested.NestedDataComplexTypeSerde;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
@@ -52,6 +55,7 @@ public class DataSegmentCommitMessage implements WriterCommitMessage {
// ExpressionMacroTable is injected in AggregatorFactories.
// However, ExprMacros are not actually required as the DataSource is write-only.
.addValue(ExprMacroTable.class, ExprMacroTable.nil())
.addValue(DefaultColumnFormatConfig.class, new DefaultColumnFormatConfig(null))
// PruneLoadSpecHolder are injected in DataSegment.
.addValue(DataSegment.PruneSpecsHolder.class, DataSegment.PruneSpecsHolder.DEFAULT);

@@ -61,12 +65,14 @@ public class DataSegmentCommitMessage implements WriterCommitMessage {

MAPPER.setTimeZone(TimeZone.getTimeZone("UTC"));

new NestedDataModule().getJacksonModules().forEach(MAPPER::registerModule);
new SketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new HllSketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new KllSketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new DoublesSketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new ArrayOfDoublesSketchModule().getJacksonModules().forEach(MAPPER::registerModule);

NestedDataModule.registerHandlersAndSerde();
HllSketchModule.registerSerde();
KllSketchModule.registerSerde();
DoublesSketchModule.registerSerde();
@@ -75,6 +81,7 @@ public class DataSegmentCommitMessage implements WriterCommitMessage {
ComplexMetrics.registerSerde("arrayOfDoublesSketch", new ArrayOfDoublesSketchMergeComplexMetricSerde());
ComplexMetrics.registerSerde("arrayOfDoublesSketchMerge", new ArrayOfDoublesSketchMergeComplexMetricSerde());
ComplexMetrics.registerSerde("arrayOfDoublesSketchBuild", new ArrayOfDoublesSketchBuildComplexMetricSerde());
ComplexMetrics.registerSerde(NestedDataComplexTypeSerde.TYPE_NAME, NestedDataComplexTypeSerde.INSTANCE);
}


39 changes: 39 additions & 0 deletions src/main/java/com/rovio/ingest/TaskDataWriter.java
@@ -15,7 +15,9 @@
*/
package com.rovio.ingest;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.rovio.ingest.model.Field;
import com.rovio.ingest.model.FieldType;
import com.rovio.ingest.model.SegmentSpec;
import com.rovio.ingest.util.ReflectionUtils;
import com.rovio.ingest.util.SegmentStorageUpdater;
@@ -38,6 +40,7 @@
import org.apache.druid.segment.indexing.RealtimeTuningConfig;
import org.apache.druid.segment.loading.DataSegmentKiller;
import org.apache.druid.segment.loading.DataSegmentPusher;
import org.apache.druid.segment.nested.StructuredData;
import org.apache.druid.segment.realtime.FireDepartmentMetrics;
import org.apache.druid.segment.realtime.appenderator.Appenderator;
import org.apache.druid.segment.realtime.appenderator.DefaultOfflineAppenderatorFactory;
@@ -50,6 +53,7 @@
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.DataType;
@@ -63,6 +67,7 @@
import java.io.IOException;
import java.time.LocalDate;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -217,6 +222,40 @@ private Map<String, Object> parse(InternalRow record) {
// Convert to Java String as Spark returns UTF8String, which is not compatible with Druid sketches.
value = value.toString();
}
if (value != null && segmentSpec.getComplexDimensionColumns().contains(columnName) && sqlType == DataTypes.StringType) {
try {
value = MAPPER.readValue(value.toString(), StructuredData.class);
} catch (JsonProcessingException e) {
value = null;
}
}
if (value != null && field.getFieldType() == FieldType.ARRAY_OF_STRING) {
ArrayData arrayData = record.getArray(field.getOrdinal());
int arraySize = arrayData.numElements();
List<String> valueArrayOfString = new ArrayList<>(arraySize);
for (int i = 0; i < arraySize; i++) {
valueArrayOfString.add(arrayData.get(i, DataTypes.StringType).toString());
}
value = valueArrayOfString;
}
if (value != null && field.getFieldType() == FieldType.ARRAY_OF_DOUBLE) {
ArrayData arrayData = record.getArray(field.getOrdinal());
int arraySize = arrayData.numElements();
Double[] valueArrayOfDouble = new Double[arraySize];
for (int i = 0; i < arraySize; i++) {
valueArrayOfDouble[i] = (Double) arrayData.get(i, DataTypes.DoubleType);
}
value = valueArrayOfDouble;
}
if (value != null && field.getFieldType() == FieldType.ARRAY_OF_LONG) {
ArrayData arrayData = record.getArray(field.getOrdinal());
int arraySize = arrayData.numElements();
Long[] valueArrayOfLong = new Long[arraySize];
for (int i = 0; i < arraySize; i++) {
valueArrayOfLong[i] = (Long) arrayData.get(i, DataTypes.LongType);
}
value = valueArrayOfLong;
}
map.put(columnName, value);
}
}
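
The new complex-dimension branch in `parse` converts string values into Druid's `StructuredData` with the same Jackson modules that `DataSegmentCommitMessage` registers. Below is a standalone sketch of that conversion (assumed usage, not code from this PR):

```java
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.guice.NestedDataModule;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.segment.nested.StructuredData;

public class StructuredDataSketch {
    public static void main(String[] args) throws JsonProcessingException {
        ObjectMapper mapper = new DefaultObjectMapper();
        // Same registration as in DataSegmentCommitMessage, so StructuredData can be read from JSON.
        new NestedDataModule().getJacksonModules().forEach(mapper::registerModule);

        // A string column holding a JSON document becomes a nested Druid value;
        // TaskDataWriter falls back to null if the string is not valid JSON.
        StructuredData nested = mapper.readValue("{\"a\": 1, \"b\": [\"x\", \"y\"]}", StructuredData.class);
        System.out.println(nested);
    }
}
```
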
35 changes: 34 additions & 1 deletion src/main/java/com/rovio/ingest/model/FieldType.java
@@ -18,11 +18,16 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

import java.util.Objects;

public enum FieldType {
TIMESTAMP,
DOUBLE,
LONG,
STRING;
STRING,
ARRAY_OF_DOUBLE,
ARRAY_OF_LONG,
ARRAY_OF_STRING;

public static FieldType from(DataType dataType) {
if (isNumericType(dataType)) {
@@ -41,6 +46,18 @@ public static FieldType from(DataType dataType) {
return STRING;
}

if (isArrayOfNumericType(dataType)) {
return ARRAY_OF_LONG;
}

if (isArrayOfDoubleType(dataType)) {
return ARRAY_OF_DOUBLE;
}

if (isArrayOfStringType(dataType)) {
return ARRAY_OF_STRING;
}

throw new IllegalArgumentException("Unsupported Type " + dataType);
}

@@ -55,4 +72,20 @@ private static boolean isNumericType(DataType dataType) {
|| dataType == DataTypes.ByteType;
}

private static boolean isArrayOfNumericType(DataType dataType) {
return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.LongType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.IntegerType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.ShortType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.ByteType));
}

private static boolean isArrayOfDoubleType(DataType dataType) {
return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.DoubleType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.FloatType));
}

private static boolean isArrayOfStringType(DataType dataType) {
return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.StringType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.BooleanType));
}

}
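
A few assumed example calls against `FieldType.from` to illustrate the new mapping (not part of this PR): integral element types widen to `ARRAY_OF_LONG`, float widens to `ARRAY_OF_DOUBLE`, and boolean arrays are treated as string arrays.

```java
import com.rovio.ingest.model.FieldType;
import org.apache.spark.sql.types.DataTypes;

public class FieldTypeMappingSketch {
    public static void main(String[] args) {
        System.out.println(FieldType.from(DataTypes.createArrayType(DataTypes.StringType)));   // ARRAY_OF_STRING
        System.out.println(FieldType.from(DataTypes.createArrayType(DataTypes.BooleanType)));  // ARRAY_OF_STRING
        System.out.println(FieldType.from(DataTypes.createArrayType(DataTypes.IntegerType)));  // ARRAY_OF_LONG
        System.out.println(FieldType.from(DataTypes.createArrayType(DataTypes.FloatType)));    // ARRAY_OF_DOUBLE
        // Nested arrays match none of the checks above and raise IllegalArgumentException:
        // FieldType.from(DataTypes.createArrayType(DataTypes.createArrayType(DataTypes.StringType)));
    }
}
```
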
21 changes: 19 additions & 2 deletions src/main/java/com/rovio/ingest/model/SegmentSpec.java
@@ -30,6 +30,8 @@
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.segment.AutoTypeColumnSchema;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.indexing.DataSchema;
import org.apache.druid.segment.indexing.granularity.GranularitySpec;
@@ -66,8 +68,8 @@ public class SegmentSpec implements Serializable {
private final String dimensionsSpec;
private final String metricsSpec;
private final String transformSpec;

private final Set<String> complexMetricColumns;
private final Set<String> complexDimensionColumns;

private SegmentSpec(String dataSource, String timeColumn, String segmentGranularity, String queryGranularity,
List<Field> fields, Field partitionTime, Field partitionNum, boolean rollup,
@@ -88,6 +90,11 @@ private SegmentSpec(String dataSource, String timeColumn, String segmentGranular
.filter(aggregatorFactory -> aggregatorFactory.getIntermediateType().is(ValueType.COMPLEX))
.flatMap((AggregatorFactory aggregatorFactory) -> aggregatorFactory.requiredFields().stream())
.collect(Collectors.toSet());
this.complexDimensionColumns = getDimensionsSpec().getDimensions()
.stream()
.filter(dimensionSchema -> dimensionSchema.getColumnType() == ColumnType.NESTED_DATA)
.map(DimensionSchema::getName)
.collect(Collectors.toSet());
}

public static SegmentSpec from(String datasource, String timeColumn, List<String> excludedDimensions,
Expand Down Expand Up @@ -127,7 +134,7 @@ public static SegmentSpec from(String datasource, String timeColumn, List<String
fields.stream().noneMatch(f -> f.getFieldType() == FieldType.TIMESTAMP && !f.getName().equals(timeColumn) && !f.getName().equals(PARTITION_TIME_COLUMN_NAME)),
String.format("Schema has another timestamp field other than \"%s\"", timeColumn));

Preconditions.checkArgument(fields.stream().anyMatch(f -> f.getFieldType() == FieldType.STRING),
Preconditions.checkArgument(fields.stream().anyMatch(f -> f.getFieldType() == FieldType.STRING || f.getFieldType() == FieldType.ARRAY_OF_STRING),
"Schema has no dimensions");

Preconditions.checkArgument(!rollup || fields.stream().anyMatch(f -> f.getFieldType() == FieldType.LONG || f.getFieldType() == FieldType.DOUBLE),
@@ -217,6 +224,12 @@ private ImmutableList<DimensionSchema> getDimensionSchemas() {
builder.add(new DoubleDimensionSchema(fieldName));
} else if (field.getFieldType() == FieldType.TIMESTAMP) {
builder.add(new LongDimensionSchema(fieldName));
} else if (field.getFieldType() == FieldType.ARRAY_OF_STRING) {
builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.STRING_ARRAY));
} else if (field.getFieldType() == FieldType.ARRAY_OF_DOUBLE) {
builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.DOUBLE_ARRAY));
} else if (field.getFieldType() == FieldType.ARRAY_OF_LONG) {
builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.LONG_ARRAY));
}
}
}
@@ -269,4 +282,8 @@ private AggregatorFactory[] getAggregators() {
public Set<String> getComplexMetricColumns() {
return complexMetricColumns;
}

public Set<String> getComplexDimensionColumns() {
return complexDimensionColumns;
}
}
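
As a rough sketch of how `complexDimensionColumns` ends up populated (the schema and `dimensionsSpec` below are invented, and the call assumes the same `SegmentSpec.from` overload exercised in the tests further down):

```java
import java.util.Collections;

import com.rovio.ingest.model.SegmentSpec;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class ComplexDimensionSketch {
    public static void main(String[] args) {
        StructType schema = new StructType()
                .add("__time", DataTypes.TimestampType)
                .add("user_id", DataTypes.StringType)
                .add("attributes", DataTypes.StringType); // JSON document carried as a string

        // A "json" dimension is read as ColumnType.NESTED_DATA, so "attributes" lands in
        // complexDimensionColumns and TaskDataWriter parses its string values into StructuredData.
        String dimensionsSpec = "{\"dimensions\": ["
                + "{\"type\": \"string\", \"name\": \"user_id\"},"
                + "{\"type\": \"json\", \"name\": \"attributes\"}]}";

        SegmentSpec spec = SegmentSpec.from("temp", "__time", Collections.emptyList(),
                "DAY", "DAY", schema, false, dimensionsSpec, "[]", null);
        System.out.println(spec.getComplexDimensionColumns()); // [attributes]
    }
}
```
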
DruidDatasetExtensions.scala
@@ -51,8 +51,8 @@ object DruidDatasetExtensions {
*/
@SerialVersionUID(1L)
implicit class DruidDataset(dataset: Dataset[Row]) extends Serializable {
private val METRIC_TYPES = Array(FloatType, DoubleType, IntegerType, LongType, ShortType, ByteType)
private val DIMENSION_TYPES = Array(StringType, DateType, TimestampType, BooleanType)
private val METRIC_TYPES = Array(FloatType, DoubleType, IntegerType, LongType, ShortType, ByteType, ArrayType(LongType), ArrayType(DoubleType))
private val DIMENSION_TYPES = Array(StringType, DateType, TimestampType, BooleanType, ArrayType(StringType))
private val log = LoggerFactory.getLogger(classOf[DruidDataset])

/**
@@ -66,7 +66,7 @@
* <p>
* The method performs the following transformations:
* <ul>
* <li>Drops all columns of complex datatypes such as `StructType`, `MapType` or `ArrayType` as they are not
* <li>Drops all columns of complex datatypes such as `StructType` or `MapType` as they are not
* supported by `DruidSource`. This is only done if `excludeColumnsWithUnknownTypes` is set to true,
* otherwise validation has already failed.</li>
* <li>Adds a new column `__PARTITION_TIME__` whose value is based on `time_column` column and the given segment
63 changes: 63 additions & 0 deletions src/test/java/com/rovio/ingest/SegmentSpecTest.java
@@ -344,4 +344,67 @@ public void shouldSupportMetricsSpecAsJson() {
assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getSegmentGranularity());
assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getQueryGranularity());
}

@Test
public void shouldSupportArrayDimensions() {
StructType schema = new StructType()
.add("updateTime", DataTypes.TimestampType)
.add("user_id", DataTypes.StringType)
.add("countries", DataTypes.createArrayType(DataTypes.StringType));
String metricsSpec = "[]";
SegmentSpec spec = SegmentSpec.from("temp", "updateTime", Collections.emptyList(), "DAY", "DAY", schema, false, metricsSpec);

assertEquals("temp", spec.getDataSchema().getDataSource());
assertEquals("updateTime", spec.getTimeColumn());
List<DimensionSchema> dimensions = spec.getDataSchema().getDimensionsSpec().getDimensions();
assertEquals(2, dimensions.size());
List<String> expected = Arrays.asList("user_id", "countries");
assertTrue(dimensions.stream().allMatch(d -> expected.contains(d.getName())));
assertTrue(dimensions.stream().anyMatch(d -> ValueType.STRING == d.getColumnType().getType() && d.getName().equals("user_id")));
assertTrue(dimensions.stream().anyMatch(d -> ValueType.ARRAY == d.getColumnType().getType() && d.getName().equals("countries")));
assertFalse(spec.getDataSchema().getGranularitySpec().isRollup());

assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getSegmentGranularity());
assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getQueryGranularity());
}

@Test
public void shouldDeserializeDimensionSpec() {
StructType schema = new StructType()
.add("__time", DataTypes.TimestampType)
.add("dim1", DataTypes.StringType)
.add("dim2", DataTypes.LongType)
.add("dim3", DataTypes.createArrayType(DataTypes.LongType))
.add("dim4", DataTypes.StringType)
.add("dim5", DataTypes.createArrayType(DataTypes.StringType))
.add("dim6", DataTypes.createArrayType(DataTypes.DoubleType));
String dimensionsSpec =
"{\"dimensions\": " +
"[{\"type\": \"string\", \"name\": \"dim1\" }," +
"{\"type\": \"long\", \"name\": \"dim2\" }," +
"{\"type\": \"auto\", \"name\": \"dim3\" }," +
"{\"type\": \"json\", \"name\": \"dim4\", \"formatVersion\": 5, \"multiValueHandling\": \"array\", \"createBitmapIndex\": true }," +
"{\"type\": \"string\", \"name\": \"dim5\", \"multiValueHandling\": \"array\", \"createBitmapIndex\": true }," +
"{\"type\": \"double\", \"name\": \"dim6\" }],\n" +
"\"includeAllDimensions\": false,\n" +
"\"useSchemaDiscovery\": false}";
String metricsSpec = "[" +
"{\n" +
" \"type\": \"longSum\",\n" +
" \"name\": \"metric2\",\n" +
" \"fieldName\": \"dim2\",\n" +
" \"expression\": null\n" +
"},\n" +
"{\n" +
" \"type\": \"doubleSum\",\n" +
" \"name\": \"metric6\",\n" +
" \"fieldName\": \"dim6\",\n" +
" \"expression\": null\n" +
"}\n" +
"]";
SegmentSpec spec = SegmentSpec.from("temp", "__time", Collections.singletonList("updateTime"), "DAY", "DAY", schema, false, dimensionsSpec, metricsSpec, null);
List<DimensionSchema> dimensions = spec.getDataSchema().getDimensionsSpec().getDimensions();
assertEquals(6, dimensions.size());
}

}