verify column types match feature types
Signed-off-by: Oleksii Moskalenko <moskalenko.alexey@gmail.com>
pyalex committed Oct 12, 2020
1 parent 6b59ec0 commit aaa39aa
Showing 4 changed files with 112 additions and 3 deletions.
@@ -18,7 +18,7 @@ package feast.ingestion

import feast.ingestion.sources.bq.BigQueryReader
import feast.ingestion.sources.file.FileReader
import feast.ingestion.validation.RowValidator
import feast.ingestion.validation.{RowValidator, TypeCheck}
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.col

@@ -56,6 +56,12 @@ object BatchPipeline extends BasePipeline {

val projected = input.select(projection: _*).cache()

TypeCheck.allTypesMatch(projected.schema, featureTable) match {
case Left(error) =>
throw new RuntimeException(s"DataFrame columns don't match expected feature types: $error")
case _ => ()
}

val validRows = projected
.filter(validator.checkAll)

@@ -20,9 +20,18 @@ import feast.ingestion.registry.proto.ProtoRegistryFactory
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.functions.udf
import feast.ingestion.utils.ProtoReflection
import feast.ingestion.validation.RowValidator
import feast.ingestion.validation.{RowValidator, TypeCheck}
import org.apache.spark.sql.streaming.StreamingQuery

/**
* Streaming pipeline (currently in micro-batch mode only, since we need multiple sinks: Redis & deadletter).
* Flow:
* 1. Read from a streaming source (currently only Kafka)
* 2. Parse bytes from the streaming source into a Row, with the schema inferred from the provided (Protobuf) class
* 3. Map columns according to the provided mapping rules
* 4. Validate
* 5. (In micro-batches) store valid rows to Redis and write invalid rows to the deadletter sink (Parquet)
*/
object StreamingPipeline extends BasePipeline with Serializable {
override def createPipeline(
sparkSession: SparkSession,
@@ -52,6 +61,12 @@ object StreamingPipeline extends BasePipeline with Serializable {
.select("features.*")
.select(projection: _*)

TypeCheck.allTypesMatch(projected.schema, featureTable) match {
case Left(error) =>
throw new RuntimeException(s"DataFrame columns don't match expected feature types: $error")
case _ => ()
}

val query = projected.writeStream
.foreachBatch { (batchDF: DataFrame, batchID: Long) =>
batchDF.persist()
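For context, the scaladoc in StreamingPipeline above outlines the micro-batch flow. The following is a minimal sketch of that flow in Spark Structured Streaming, not the actual Feast implementation: the topic name, broker address, parsing step, validity rule, and output paths are illustrative assumptions (the real pipeline parses Protobuf bytes and writes valid rows to Redis).

// Hypothetical sketch of the documented flow: read from Kafka, parse, validate,
// then split each micro-batch into valid rows and deadletter rows.
// Topic, broker, validity rule and output paths are assumptions for illustration.
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object MicroBatchFlowSketch {
  def run(spark: SparkSession): Unit = {
    val raw = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092") // assumed broker
      .option("subscribe", "feature-topic")                // assumed topic
      .load()

    // Step 2 stand-in: the real pipeline parses the value bytes with a Protobuf-derived UDF.
    val parsed = raw.selectExpr("CAST(value AS STRING) AS feature_payload")

    // Step 4 stand-in: a trivial validity rule instead of RowValidator.checkAll.
    val isValid = col("feature_payload").isNotNull

    val query = parsed.writeStream
      .foreachBatch { (batch: DataFrame, batchId: Long) =>
        batch.persist()
        batch.filter(isValid).write.mode("append").parquet("/tmp/valid")       // Redis sink in the real pipeline
        batch.filter(!isValid).write.mode("append").parquet("/tmp/deadletter") // deadletter sink
        batch.unpersist()
        ()
      }
      .start()

    query.awaitTermination()
  }
}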
@@ -0,0 +1,67 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2020 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.ingestion.validation

import feast.ingestion.FeatureTable
import feast.proto.types.ValueProto.ValueType
import org.apache.spark.sql.types._

object TypeCheck {
def typesMatch(defType: ValueType.Enum, columnType: DataType): Boolean =
(defType, columnType) match {
case (ValueType.Enum.BOOL, BooleanType) => true
case (ValueType.Enum.INT32, IntegerType) => true
case (ValueType.Enum.INT64, LongType) => true
case (ValueType.Enum.FLOAT, FloatType) => true
case (ValueType.Enum.DOUBLE, DoubleType) => true
case (ValueType.Enum.STRING, StringType) => true
case (ValueType.Enum.BYTES, BinaryType) => true
case (ValueType.Enum.BOOL_LIST, ArrayType(_: BooleanType, _)) => true
case (ValueType.Enum.INT32_LIST, ArrayType(_: IntegerType, _)) => true
case (ValueType.Enum.INT64_LIST, ArrayType(_: LongType, _)) => true
case (ValueType.Enum.FLOAT_LIST, ArrayType(_: FloatType, _)) => true
case (ValueType.Enum.DOUBLE_LIST, ArrayType(_: DoubleType, _)) => true
case (ValueType.Enum.STRING_LIST, ArrayType(_: StringType, _)) => true
case (ValueType.Enum.BYTES_LIST, ArrayType(_: BinaryType, _)) => true
case _ => false
}

/**
* Verify whether types declared in the FeatureTable match the corresponding DataFrame columns
* @param schema Spark DataFrame schema
* @param featureTable definition of the expected schema
* @return Right(true) if all types match, Left(error message) otherwise
*/
def allTypesMatch(schema: StructType, featureTable: FeatureTable): Either[String, Boolean] = {
val typeByField =
(featureTable.entities ++ featureTable.features).map(f => f.name -> f.`type`).toMap

schema.fields
.map(f =>
if (typeByField.contains(f.name) && !typesMatch(typeByField(f.name), f.dataType)) {
Left(
s"Feature ${f.name} has different type ${f.dataType} instead of expected ${typeByField(f.name)}"
)
} else Right(true)
)
.foldLeft[Either[String, Boolean]](Right(true)) {
case (Left(error), _) => Left(error)
case (Right(_), Left(error)) => Left(error)
case _ => Right(true)
}
}
}
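As a usage note, here is a small sketch of how the new allTypesMatch check behaves. It assumes FeatureTable and Field live in feast.ingestion, as suggested by the integration test below; the table and schemas are made up for illustration.

// Hypothetical usage of TypeCheck.allTypesMatch; the feature table and schemas are illustrative,
// and the feast.ingestion.Field import path is an assumption based on the test code below.
import feast.ingestion.{FeatureTable, Field}
import feast.ingestion.validation.TypeCheck
import feast.proto.types.ValueProto.ValueType
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object TypeCheckUsageSketch {
  val table = FeatureTable(
    name = "driver-fs",
    project = "default",
    entities = Seq(Field("s2_id", ValueType.Enum.INT64)),
    features = Seq(Field("unique_drivers", ValueType.Enum.INT64))
  )

  // Column types match the declared feature types: Right(true)
  val ok = TypeCheck.allTypesMatch(
    StructType(Seq(StructField("s2_id", LongType), StructField("unique_drivers", LongType))),
    table
  )

  // "unique_drivers" arrives as a string column: Left(error describing the mismatch)
  val mismatch = TypeCheck.allTypesMatch(
    StructType(Seq(StructField("s2_id", LongType), StructField("unique_drivers", StringType))),
    table
  )
}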
@@ -69,7 +69,7 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer {
name = "driver-fs",
project = "default",
entities = Seq(
Field("s2_id", ValueType.Enum.INT32),
Field("s2_id", ValueType.Enum.INT64),
Field("vehicle_type", ValueType.Enum.STRING)
),
features = Seq(
@@ -256,4 +256,25 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer {
)
)
}

"Expected feature types" should "match source types" in new Scope {
val configWithKafka = config.copy(
source = kafkaSource,
featureTable = FeatureTable(
name = "driver-fs",
project = "default",
entities = Seq(
Field("s2_id", ValueType.Enum.STRING),
Field("vehicle_type", ValueType.Enum.INT32)
),
features = Seq(
Field("unique_drivers", ValueType.Enum.FLOAT)
)
)
)

assertThrows[RuntimeException] {
StreamingPipeline.createPipeline(sparkSession, configWithKafka).get
}
}
}
