Skip to content

Commit

Permalink
Completed implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Sep 12, 2022
1 parent 8f5966b commit 7277a90
Show file tree
Hide file tree
Showing 3 changed files with 497 additions and 197 deletions.
208 changes: 120 additions & 88 deletions src/main/scala/uk/co/gresearch/spark/diff/Diff.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package uk.co.gresearch.spark.diff
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, StringType}
import uk.co.gresearch.spark.diff.DiffMode.DiffMode
import uk.co.gresearch.spark.{backticks, distinctPrefixFor}

import scala.collection.JavaConverters
Expand All @@ -37,89 +36,92 @@ class Differ(options: DiffOptions) {
s"Left column names: ${left.columns.mkString(", ")}\n" +
s"Right column names: ${right.columns.mkString(", ")}")

require(left.columns.length == right.columns.length,
val ignoreColumnsCs = ignoreColumns.map(handleConfiguredCaseSensitivity).toSet
def isIgnoredColumn(column: String): Boolean = !ignoreColumnsCs.contains(handleConfiguredCaseSensitivity(column))
val leftNonIgnored = left.columns.filter(isIgnoredColumn)
val rightNonIgnored = right.columns.filter(isIgnoredColumn)

def notInWithCaseSensitivity(columns: Seq[String])(column: String): Boolean =
!columns.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(column))

val exceptIgnoredColumnsMsg = if (ignoreColumns.nonEmpty) " except ignored columns" else ""

require(leftNonIgnored.length == rightNonIgnored.length,
"The number of columns doesn't match.\n" +
s"Left column names (${left.columns.length}): ${left.columns.mkString(", ")}\n" +
s"Right column names (${right.columns.length}): ${right.columns.mkString(", ")}")
s"Left column names$exceptIgnoredColumnsMsg (${leftNonIgnored.length}): ${leftNonIgnored.mkString(", ")}\n" +
s"Right column names$exceptIgnoredColumnsMsg (${rightNonIgnored.length}): ${rightNonIgnored.mkString(", ")}")

require(left.columns.length > 0, "The schema must not be empty")
require(leftNonIgnored.length > 0, s"The schema$exceptIgnoredColumnsMsg must not be empty")

// column types must match but we ignore the nullability of columns
val leftFields = left.schema.fields.map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
val rightFields = right.schema.fields.map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
val leftFields = left.schema.fields.filter(f => isIgnoredColumn(f.name)).map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
val rightFields = right.schema.fields.filter(f => isIgnoredColumn(f.name)).map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
val leftExtraSchema = leftFields.diff(rightFields)
val rightExtraSchema = rightFields.diff(leftFields)
require(leftExtraSchema.isEmpty && rightExtraSchema.isEmpty,
"The datasets do not have the same schema.\n" +
s"Left extra columns: ${leftExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}\n" +
s"Right extra columns: ${rightExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}")

val columns = left.columns.map(handleConfiguredCaseSensitivity)
val pkColumns = if (idColumns.isEmpty) columns.toList else idColumns.map(handleConfiguredCaseSensitivity)
val nonPkColumns = columns.diff(pkColumns)
val missingIdColumns = pkColumns.diff(columns)
val columns = leftNonIgnored
val pkColumns = if (idColumns.isEmpty) columns.toList else idColumns
val nonPkColumns = columns.filter(notInWithCaseSensitivity(pkColumns))
val missingIdColumns = pkColumns.filter(notInWithCaseSensitivity(columns))
require(missingIdColumns.isEmpty,
s"Some id columns do not exist: ${missingIdColumns.mkString(", ")} missing among ${columns.mkString(", ")}")

require(!pkColumns.contains(handleConfiguredCaseSensitivity(options.diffColumn)),
val missingIgnoreColumns = ignoreColumns.diffCaseSensitivity(left.columns).diffCaseSensitivity(right.columns)
require(missingIgnoreColumns.isEmpty,
s"Some ignore columns do not exist: ${missingIgnoreColumns.mkString(", ")} " +
s"missing among ${(leftNonIgnored ++ rightNonIgnored).distinct.sorted.mkString(", ")}")

require(notInWithCaseSensitivity(pkColumns)(options.diffColumn),
s"The id columns must not contain the diff column name '${options.diffColumn}': " +
s"${pkColumns.mkString(", ")}")
if(Set(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode))
require(!nonPkColumns.contains(options.diffColumn),
require(notInWithCaseSensitivity(nonPkColumns)(options.diffColumn),
s"The non-id columns must not contain the diff column name '${options.diffColumn}': ${nonPkColumns.mkString((", "))}")

require(!options.changeColumn.exists(pkColumns.contains),
require(options.changeColumn.forall(notInWithCaseSensitivity(pkColumns)),
s"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString((", "))}")
if(Set(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode))
require(!options.changeColumn.exists(nonPkColumns.contains),
require(!options.changeColumn.exists(notInWithCaseSensitivity(nonPkColumns)),
s"The non-id columns must not contain the change column name '${options.changeColumn.get}': ${nonPkColumns.mkString((", "))}")

val nonIdColumns = columns.diff(pkColumns)
val diffValueColumns = getDiffValueColumns(nonIdColumns, options.diffMode)
val diffValueColumns = getDiffColumns(pkColumns, nonPkColumns, left, right, ignoreColumns).map(_._1).diff(pkColumns)

require(!diffValueColumns.contains(handleConfiguredCaseSensitivity(options.diffColumn)),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce the diff column name '${options.diffColumn}': " +
s"${nonIdColumns.mkString(", ")}")
if (Seq(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode)) {
require(notInWithCaseSensitivity(diffValueColumns)(options.diffColumn),
s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " +
s"non-id columns must not contain the diff column name '${options.diffColumn}': " +
s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}")

options.changeColumn.foreach( changeColumn =>
require(!diffValueColumns.contains(handleConfiguredCaseSensitivity(changeColumn)),
options.changeColumn.foreach( changeColumn =>
require(notInWithCaseSensitivity(diffValueColumns)(changeColumn),
s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " +
s"non-id columns must not contain the change column name '${options.changeColumn.get}': " +
s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}")
)
} else {
require(notInWithCaseSensitivity(diffValueColumns)(options.diffColumn),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce the change column name '${changeColumn}': " +
s"${nonIdColumns.mkString(", ")}")
)

require(diffValueColumns.forall(column => !pkColumns.contains(column)),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce any id column name '${pkColumns.mkString("', '")}': " +
s"${nonIdColumns.mkString(", ")}")
}

/**
* Produces the left and right value columns (non-id columns).
* @param nonIdColumns value column names
* @return left and right diff value column names
*/
private[diff] def getDiffValueColumns(nonIdColumns: Seq[String], diffMode: DiffMode): Seq[String] = {
def prefixColumns(columns: Seq[String])(prefix: String): Seq[String] =
columns.map(column => s"${prefix}_$column")

diffMode match {
case DiffMode.ColumnByColumn =>
Seq(options.leftColumnPrefix, options.rightColumnPrefix)
.flatMap(prefixColumns(nonIdColumns))
.map(handleConfiguredCaseSensitivity)

case DiffMode.SideBySide =>
prefixColumns(nonIdColumns)(options.leftColumnPrefix) ++
prefixColumns(nonIdColumns)(options.rightColumnPrefix)
.map(handleConfiguredCaseSensitivity)
s"must not produce the diff column name '${options.diffColumn}': " +
s"${nonPkColumns.mkString(", ")}")

options.changeColumn.foreach( changeColumn =>
require(notInWithCaseSensitivity(diffValueColumns)(changeColumn),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce the change column name '${changeColumn}': " +
s"${nonPkColumns.mkString(", ")}")
)

case DiffMode.LeftSide | DiffMode.RightSide =>
nonIdColumns
require(diffValueColumns.forall(column => notInWithCaseSensitivity(pkColumns)(column)),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce any id column name '${pkColumns.mkString("', '")}': " +
s"${nonPkColumns.mkString(", ")}")
}
}

Expand All @@ -144,53 +146,83 @@ class Differ(options: DiffOptions) {
)
}

private[diff] def getDiffColumns[T, U](pkColumns: Seq[String], otherColumns: Seq[String],
private[diff] def getDiffColumns[T, U](pkColumns: Seq[String], valueColumns: Seq[String],
left: Dataset[T], right: Dataset[U],
ignoreColumns: Seq[String]): Seq[Column] = {
val idColumns = pkColumns.map(c => coalesce(left(backticks(c)), right(backticks(c))).as(c))
ignoreColumns: Seq[String]): Seq[(String, Column)] = {
val idColumns = pkColumns.map(c => c -> coalesce(left(backticks(c)), right(backticks(c))).as(c))

val leftValueColumns = left.columns.filterIsInCaseSensitivity(valueColumns)
val rightValueColumns = right.columns.filterIsInCaseSensitivity(valueColumns)

val leftNonPkColumns = left.columns.diffCaseSensitivity(pkColumns)
val rightNonPkColumns = right.columns.diffCaseSensitivity(pkColumns)

val leftIgnoredColumns = left.columns.filterIsInCaseSensitivity(ignoreColumns)
val rightIgnoredColumns = right.columns.filterIsInCaseSensitivity(ignoreColumns)

val (leftValues, rightValues) = if (options.sparseMode) {
(
otherColumns.map(c => (c, if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c))) else left(backticks(c)))).toMap,
otherColumns.map(c => (c, if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c))) else right(backticks(c)))).toMap
leftNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c))) else left(backticks(c))))).toMap,
rightNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c))) else right(backticks(c))))).toMap
)
} else {
(
otherColumns.map(c => (c, left(backticks(c)))).toMap,
otherColumns.map(c => (c, right(backticks(c)))).toMap
leftNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> left(backticks(c)))).toMap,
rightNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> right(backticks(c)))).toMap,
)
}

val valueColumns = options.diffMode match {
def alias(prefix: Option[String], values: Map[String, (String, Column)])(name: String): (String, Column) = {
values(handleConfiguredCaseSensitivity(name)) match {
case (name, column) =>
val alias = prefix.map(p => s"${p}_$name").getOrElse(name)
alias -> column.as(alias)
}
}

def aliasLeft(name: String): (String, Column) = alias(Some(options.leftColumnPrefix), leftValues)(name)

def aliasRight(name: String): (String, Column) = alias(Some(options.rightColumnPrefix), rightValues)(name)

val prefixedLeftIgnoredColumns = leftIgnoredColumns.map(c => aliasLeft(c))
val prefixedRightIgnoredColumns = rightIgnoredColumns.map(c => aliasRight(c))

val nonIdColumns = options.diffMode match {
case DiffMode.ColumnByColumn =>
otherColumns.flatMap(c =>
valueColumns.flatMap(c =>
Seq(
leftValues(c).as(s"${options.leftColumnPrefix}_$c"),
rightValues(c).as(s"${options.rightColumnPrefix}_$c")
aliasLeft(c),
aliasRight(c)
)
) ++ ignoreColumns.flatMap(c =>
(if (leftIgnoredColumns.containsCaseSensitivity(c)) Seq(aliasLeft(c)) else Seq.empty) ++
(if (rightIgnoredColumns.containsCaseSensitivity(c)) Seq(aliasRight(c)) else Seq.empty)
)

case DiffMode.SideBySide =>
otherColumns.map(c => leftValues(c).as(s"${options.leftColumnPrefix}_$c")) ++
otherColumns.map(c => rightValues(c).as(s"${options.rightColumnPrefix}_$c"))
leftValueColumns.toSeq.map(c => aliasLeft(c)) ++ prefixedLeftIgnoredColumns ++
rightValueColumns.toSeq.map(c => aliasRight(c)) ++ prefixedRightIgnoredColumns

case DiffMode.LeftSide | DiffMode.RightSide =>
otherColumns.map(c =>
if (options.diffMode == DiffMode.LeftSide) leftValues(c).as(c) else rightValues(c).as(c)
)
// in left-side / right-side mode, we do not prefix columns
(
if (options.diffMode == DiffMode.LeftSide) valueColumns.map(alias(None, leftValues)) else valueColumns.map(alias(None, rightValues))
) ++ (
if (options.diffMode == DiffMode.LeftSide) leftIgnoredColumns.map(alias(None, leftValues)) else rightIgnoredColumns.map(alias(None, rightValues))
)
}
idColumns ++ valueColumns
idColumns ++ nonIdColumns
}

private def doDiff[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String] = Seq.empty): DataFrame = {
checkSchema(left, right, idColumns, ignoreColumns)

val pkColumns = if (idColumns.isEmpty) left.columns.toList else idColumns
val pkColumnsCs = pkColumns.map(handleConfiguredCaseSensitivity).toSet
val nonPkColumns = left.columns.filter(col => !pkColumnsCs.contains(handleConfiguredCaseSensitivity(col)))

val ignoreColumnsCs = ignoreColumns.map(handleConfiguredCaseSensitivity).toSet
val valueColumns = nonPkColumns.filter(col => !ignoreColumnsCs(handleConfiguredCaseSensitivity(col)))

val columns = left.columns.filter(c => !ignoreColumnsCs.contains(handleConfiguredCaseSensitivity(c))).toList
val pkColumns = if (idColumns.isEmpty) columns else idColumns
val pkColumnsCs = pkColumns.map(handleConfiguredCaseSensitivity).toSet
val valueColumns = columns.filter(col => !pkColumnsCs.contains(handleConfiguredCaseSensitivity(col)))

val existsColumnName = distinctPrefixFor(left.columns) + "exists"
val leftWithExists = left.withColumn(existsColumnName, lit(1))
Expand All @@ -206,7 +238,7 @@ class Differ(options: DiffOptions) {
otherwise(lit(options.nochangeDiffValue)).
as(options.diffColumn)

val diffColumns = getDiffColumns(pkColumns, nonPkColumns, left, right, ignoreColumns)
val diffColumns = getDiffColumns(pkColumns, valueColumns, left, right, ignoreColumns).map(_._2)
val changeColumn = getChangeColumn(existsColumnName, valueColumns, leftWithExists, rightWithExists)
// turn this column into a sequence of one or none column so we can easily concat it below with diffActionColumn and diffColumns
.map(Seq(_))
Expand All @@ -217,9 +249,9 @@ class Differ(options: DiffOptions) {
}

/**
* Returns a new DataFrame that contains the differences between two Datasets
* of the same type `T`. Both Datasets must contain the same set of column names and data types.
* The order of columns in the two Datasets is not important as columns are compared based on the
* Returns a new DataFrame that contains the differences between two Datasets of
* the same type `T`. Both Datasets must contain the same set of column names and data types.
* The order of columns in the two Datasets is not relevant as columns are compared based on the
* name, not the the position.
*
* Optional `id` columns are used to uniquely identify rows to compare. If values in any non-id
Expand Down Expand Up @@ -277,7 +309,7 @@ class Differ(options: DiffOptions) {
/**
* Returns a new DataFrame that contains the differences between two Datasets of
* similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
* except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not important as
* except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
* columns are compared based on the name, not the the position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id
Expand Down Expand Up @@ -441,10 +473,10 @@ class Differ(options: DiffOptions) {
*/
def diffAs[T, U, V](left: Dataset[T], right: Dataset[U],
diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V] = {
val nonIdColumns = left.columns.diff(if (idColumns.isEmpty) left.columns.toList else idColumns)
val nonIdColumns = if (idColumns.isEmpty) Seq.empty else left.columns.diffCaseSensitivity(idColumns).diffCaseSensitivity(ignoreColumns).toSeq
val encColumns = diffEncoder.schema.fields.map(_.name)
val diffColumns = Seq(options.diffColumn) ++ idColumns ++ getDiffValueColumns(nonIdColumns, options.diffMode)
val extraColumns = encColumns.diff(diffColumns)
val diffColumns = Seq(options.diffColumn) ++ getDiffColumns(idColumns, nonIdColumns, left, right, ignoreColumns).map(_._1)
val extraColumns = encColumns.diffCaseSensitivity(diffColumns)

require(extraColumns.isEmpty,
s"Diff encoder's columns must be part of the diff result schema, " +
Expand Down Expand Up @@ -599,7 +631,7 @@ object Diff {
/**
* Returns a new DataFrame that contains the differences between two Datasets of
* similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
* except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not important as
* except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
* columns are compared based on the name, not the the position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id
Expand Down Expand Up @@ -657,7 +689,7 @@ object Diff {
/**
* Returns a new DataFrame that contains the differences between two Datasets of
* similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
* except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not important as
* except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
* columns are compared based on the name, not the the position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id
Expand Down
Loading

0 comments on commit 7277a90

Please sign in to comment.