Correctly recurse into nested arrays & maps in add/drop columns
It is currently not possible in Delta tables to add or drop nested fields under two or more levels of directly nested arrays or maps.
The following is a valid use case but fails today:
```
CREATE TABLE test (data array<array<struct<a: int>>>)
ALTER TABLE test ADD COLUMNS (data.element.element.b string)
```
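
For illustration, here is a rough sketch (not part of this change) of how the fixed `findColumnPosition` resolves the parent path from this example into the ordinal-based position used by `SchemaUtils`. The schema value is built inline and the expected result assumes `ARRAY_ELEMENT_INDEX` keeps its documented value of 0:

```
// Sketch only: resolve the parent path `data.element.element` from the example above.
// Assumes ARRAY_ELEMENT_INDEX == 0, as documented in SchemaUtils.
import org.apache.spark.sql.types._
import org.apache.spark.sql.delta.schema.SchemaUtils

val schema = StructType(Seq(
  StructField("data", ArrayType(ArrayType(StructType(Seq(
    StructField("a", IntegerType))))))))

// Field `data` is ordinal 0 at the root; each `element` step adds ARRAY_ELEMENT_INDEX.
val parentPosition = SchemaUtils.findColumnPosition(Seq("data", "element", "element"), schema)
// Expected: Seq(0, 0, 0), i.e. the innermost struct, to which the new column `b` is appended.
```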

This change updates helper methods `findColumnPosition`, `addColumn` and `dropColumn` in `SchemaUtils` to correctly recurse into directly nested maps and arrays.

Note that changes in Spark are also required for `ALTER TABLE ADD/CHANGE/DROP COLUMN` to work: apache/spark#40879. The fix is merged in Spark but will only become available in Delta with the next Spark release.

In addition, `findColumnPosition`, which currently returns both the position of the nested field and the size of its parent, making it overly complex, is split into two distinct and generic methods: `findColumnPosition` and `getNestedTypeFromPosition`.
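
The sketch below shows how the two methods are meant to compose after the split, mirroring the `AlterTableAddColumnsDeltaCommand` change in this diff. The map schema is hypothetical and the expected ordinals assume `MAP_VALUE_INDEX == 1`:

```
// Sketch only: resolve a name path to ordinals, then look up the type at that position.
import org.apache.spark.sql.types._
import org.apache.spark.sql.delta.schema.SchemaUtils

val schema = StructType(Seq(
  StructField("m", MapType(StringType,
    StructType(Seq(StructField("x", IntegerType)))))))

val position = SchemaUtils.findColumnPosition(Seq("m", "value"), schema)
// Expected: Seq(0, 1), i.e. field `m`, then the map's value type (MAP_VALUE_INDEX).

val parentType = SchemaUtils.getNestedTypeFromPosition(schema, position)
// Expected: the struct <x: int>. A caller appending a column would use its size as the
// insert ordinal, as AlterTableAddColumnsDeltaCommand does in the diff below.
```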

- Tests for `findColumnPosition`, `addColumn` and `dropColumn` with two levels of nested maps and arrays are added to `SchemaUtilsSuite` (a sketch of such a check follows this list); other cases for these methods are already covered by existing tests.
- Tested locally that `ALTER TABLE ADD/CHANGE/DROP COLUMN(S)` works correctly with the Spark fix apache/spark#40879.
- Added missing test coverage for `ALTER TABLE ADD/CHANGE/DROP COLUMN(S)` with a single map or array.
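
The following is a hypothetical sketch of the kind of check added, written against a plain ScalaTest `AnyFunSuite` rather than the repository's actual test base classes, and again assuming `ARRAY_ELEMENT_INDEX == 0`:

```
// Hypothetical test sketch, not the exact tests added to SchemaUtilsSuite.
import org.apache.spark.sql.types._
import org.apache.spark.sql.delta.schema.SchemaUtils
import org.scalatest.funsuite.AnyFunSuite

class NestedRecursionSketchSuite extends AnyFunSuite {
  test("add a column under two levels of directly nested arrays") {
    val schema = StructType(Seq(
      StructField("data", ArrayType(ArrayType(StructType(Seq(
        StructField("a", IntegerType))))))))
    // Position of the parent struct: field `data`, then two array-element hops.
    val parentPosition =
      SchemaUtils.findColumnPosition(Seq("data", "element", "element"), schema)
    assert(parentPosition == Seq(0, 0, 0))
    // Append `b` as the second field (ordinal 1) of the innermost struct.
    val newSchema =
      SchemaUtils.addColumn(schema, StructField("b", StringType), parentPosition :+ 1)
    assert(newSchema == StructType(Seq(
      StructField("data", ArrayType(ArrayType(StructType(Seq(
        StructField("a", IntegerType),
        StructField("b", StringType)))))))))
  }
}
```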

Closes delta-io#1731

GitOrigin-RevId: 53ed05813f4002ae986926506254d780e2ecddfa
johanl-db authored and allisonport-db committed May 10, 2023
1 parent 5c3f4d3 commit 243c0eb
Showing 8 changed files with 841 additions and 174 deletions.
@@ -1959,7 +1959,7 @@ trait DeltaErrorsBase
new DeltaAnalysisException("DELTA_UNSUPPORTED_DROP_COLUMN", Array(adviceMsg))
}

def dropNestedColumnsFromNonStructTypeException(struct : StructField) : Throwable = {
def dropNestedColumnsFromNonStructTypeException(struct : DataType) : Throwable = {
new DeltaAnalysisException(
errorClass = "DELTA_UNSUPPORTED_DROP_NESTED_COLUMN_FROM_NON_STRUCT_TYPE",
messageParameters = Array(s"$struct")
@@ -219,15 +219,19 @@ case class AlterTableAddColumnsDeltaCommand(
val resolver = sparkSession.sessionState.conf.resolver
val newSchema = colsToAddWithPosition.foldLeft(oldSchema) {
case (schema, QualifiedColTypeWithPosition(columnPath, column, None)) =>
val (parentPosition, lastSize) =
SchemaUtils.findColumnPosition(columnPath, schema, resolver)
SchemaUtils.addColumn(schema, column, parentPosition :+ lastSize)
val parentPosition = SchemaUtils.findColumnPosition(columnPath, schema, resolver)
val insertPosition = SchemaUtils.getNestedTypeFromPosition(schema, parentPosition) match {
case s: StructType => s.size
case other =>
throw DeltaErrors.addColumnParentNotStructException(column, other)
}
SchemaUtils.addColumn(schema, column, parentPosition :+ insertPosition)
case (schema, QualifiedColTypeWithPosition(columnPath, column, Some(_: First))) =>
val (parentPosition, _) = SchemaUtils.findColumnPosition(columnPath, schema, resolver)
val parentPosition = SchemaUtils.findColumnPosition(columnPath, schema, resolver)
SchemaUtils.addColumn(schema, column, parentPosition :+ 0)
case (schema,
QualifiedColTypeWithPosition(columnPath, column, Some(after: After))) =>
val (prevPosition, _) =
val prevPosition =
SchemaUtils.findColumnPosition(columnPath :+ after.column, schema, resolver)
val position = prevPosition.init :+ (prevPosition.last + 1)
SchemaUtils.addColumn(schema, column, position)
@@ -294,7 +298,7 @@ case class AlterTableDropColumnsDeltaCommand(
throw DeltaErrors.dropColumnNotSupported(suggestUpgrade = true)
}
val newSchema = columnsToDrop.foldLeft(metadata.schema) { case (schema, columnPath) =>
val (parentPosition, _) =
val parentPosition =
SchemaUtils.findColumnPosition(
columnPath, schema, sparkSession.sessionState.conf.resolver)
SchemaUtils.dropColumn(schema, parentPosition)._1
251 changes: 147 additions & 104 deletions core/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
@@ -518,8 +518,8 @@ object SchemaUtils extends DeltaLogging {
}

/**
* Returns the given column's ordinal within the given `schema` and the size of the last schema
* size. The length of the returned position will be as long as how nested the column is.
* Returns the path of the given column in `schema` as a list of ordinals (0-based), each value
* representing the position at the current nesting level starting from the root.
*
* For ArrayType: accessing the array's element adds a position 0 to the position list.
* e.g. accessing a.element.y would have the result -> Seq(..., positionOfA, 0, positionOfY)
@@ -538,65 +538,109 @@
def findColumnPosition(
column: Seq[String],
schema: StructType,
resolver: Resolver = DELTA_COL_RESOLVER): (Seq[Int], Int) = {
def find(column: Seq[String], schema: StructType, stack: Seq[String]): (Seq[Int], Int) = {
if (column.isEmpty) return (Nil, schema.size)
val thisCol = column.head
lazy val columnPath = UnresolvedAttribute(stack :+ thisCol).name
val pos = schema.indexWhere(f => resolver(f.name, thisCol))
if (pos == -1) {
throw new IndexOutOfBoundsException(columnPath)
}
val colTail = column.tail
val (children, lastSize) = (colTail, schema(pos).dataType) match {
case (_, s: StructType) =>
find(colTail, s, stack :+ thisCol)
case (Seq("element", _ @ _*), ArrayType(s: StructType, _)) =>
val (child, size) = find(colTail.tail, s, stack :+ thisCol)
(ARRAY_ELEMENT_INDEX +: child, size)
case (Seq(), ArrayType(s: StructType, _)) =>
find(colTail, s, stack :+ thisCol)
case (Seq(), ArrayType(_, _)) =>
(Nil, 0)
case (_, ArrayType(_, _)) =>
throw DeltaErrors.incorrectArrayAccessByName(
prettyFieldName(stack ++ Seq(thisCol, "element")),
prettyFieldName(stack ++ Seq(thisCol)))
case (Seq(), MapType(_, _, _)) =>
(Nil, 2)
case (Seq("key", _ @ _*), MapType(keyType: StructType, _, _)) =>
val (child, size) = find(colTail.tail, keyType, stack :+ thisCol)
(MAP_KEY_INDEX +: child, size)
case (Seq("key"), MapType(_, _, _)) =>
(Seq(MAP_KEY_INDEX), 0)
case (Seq("value", _ @ _*), MapType(_, valueType: StructType, _)) =>
val (child, size) = find(colTail.tail, valueType, stack :+ thisCol)
(MAP_VALUE_INDEX +: child, size)
case (Seq("value"), MapType(_, _, _)) =>
(Seq(MAP_VALUE_INDEX), 0)
case (_, MapType(_, _, _)) =>
throw DeltaErrors.foundMapTypeColumnException(
prettyFieldName(stack ++ Seq(thisCol, "key")),
prettyFieldName(stack ++ Seq(thisCol, "value")))
case (_, o) =>
if (column.length > 1) {
throw DeltaErrors.columnPathNotNested(columnPath, o, column)
resolver: Resolver = DELTA_COL_RESOLVER): Seq[Int] = {
def findRecursively(
searchPath: Seq[String],
currentType: DataType,
currentPath: Seq[String] = Nil): Seq[Int] = {
if (searchPath.isEmpty) return Nil

val currentFieldName = searchPath.head
val currentPathWithNestedField = currentPath :+ currentFieldName
(currentType, currentFieldName) match {
case (struct: StructType, _) =>
lazy val columnPath = UnresolvedAttribute(currentPathWithNestedField).name
val pos = struct.indexWhere(f => resolver(f.name, currentFieldName))
if (pos == -1) {
throw DeltaErrors.columnNotInSchemaException(columnPath, schema)
}
(Nil, 0)
val childPosition = findRecursively(
searchPath = searchPath.tail,
currentType = struct(pos).dataType,
currentPath = currentPathWithNestedField)
pos +: childPosition

case (map: MapType, "key") =>
val childPosition = findRecursively(
searchPath = searchPath.tail,
currentType = map.keyType,
currentPath = currentPathWithNestedField)
MAP_KEY_INDEX +: childPosition

case (map: MapType, "value") =>
val childPosition = findRecursively(
searchPath = searchPath.tail,
currentType = map.valueType,
currentPath = currentPathWithNestedField)
MAP_VALUE_INDEX +: childPosition

case (_: MapType, _) =>
throw DeltaErrors.foundMapTypeColumnException(
prettyFieldName(currentPath :+ "key"),
prettyFieldName(currentPath :+ "value"))

case (array: ArrayType, "element") =>
val childPosition = findRecursively(
searchPath = searchPath.tail,
currentType = array.elementType,
currentPath = currentPathWithNestedField)
ARRAY_ELEMENT_INDEX +: childPosition

case (_: ArrayType, _) =>
throw DeltaErrors.incorrectArrayAccessByName(
prettyFieldName(currentPath :+ "element"),
prettyFieldName(currentPath))
case _ =>
throw DeltaErrors.columnPathNotNested(currentFieldName, currentType, currentPath)
}
(Seq(pos) ++ children, lastSize)
}

try {
find(column, schema, Nil)
findRecursively(column, schema)
} catch {
case i: IndexOutOfBoundsException =>
throw DeltaErrors.columnNotInSchemaException(i.getMessage, schema)
case e: AnalysisException =>
throw new AnalysisException(e.getMessage + s":\n${schema.treeString}")
}
}

/**
* Returns the nested field at the given position in `parent`. See [[findColumnPosition]] for the
* representation used for `position`.
* @param parent The field used for the lookup.
* @param position A list of ordinals (0-based) representing the path to the nested field in
* `parent`.
*/
def getNestedFieldFromPosition(parent: StructField, position: Seq[Int]): StructField = {
if (position.isEmpty) return parent

val fieldPos = position.head
parent.dataType match {
case struct: StructType if fieldPos >= 0 && fieldPos < struct.size =>
getNestedFieldFromPosition(struct(fieldPos), position.tail)
case map: MapType if fieldPos == MAP_KEY_INDEX =>
getNestedFieldFromPosition(StructField("key", map.keyType), position.tail)
case map: MapType if fieldPos == MAP_VALUE_INDEX =>
getNestedFieldFromPosition(StructField("value", map.valueType), position.tail)
case array: ArrayType if fieldPos == ARRAY_ELEMENT_INDEX =>
getNestedFieldFromPosition(StructField("element", array.elementType), position.tail)
case _: StructType | _: ArrayType | _: MapType =>
throw new IllegalArgumentException(
s"Invalid child position $fieldPos in ${parent.dataType}")
case other =>
throw new IllegalArgumentException(s"Invalid indexing into non-nested type $other")
}
}

/**
* Returns the nested type at the given position in `schema`. See [[findColumnPosition]] for the
* representation used for `position`.
* @param schema The root schema used for the lookup.
* @param position A list of ordinals (0-based) representing the path to the nested field in
* `schema`.
*/
def getNestedTypeFromPosition(schema: StructType, position: Seq[Int]): DataType =
getNestedFieldFromPosition(StructField("schema", schema), position).dataType

/**
* Pretty print the column path passed in.
*/
@@ -616,6 +660,24 @@
* result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,**c2**,c3>>
*/
def addColumn(schema: StructType, column: StructField, position: Seq[Int]): StructType = {
def addColumnInChild(parent: DataType, column: StructField, position: Seq[Int]): DataType = {
require(position.nonEmpty, s"Don't know where to add the column $column")
parent match {
case struct: StructType =>
addColumn(struct, column, position)
case map: MapType if position.head == MAP_KEY_INDEX =>
map.copy(keyType = addColumnInChild(map.keyType, column, position.tail))
case map: MapType if position.head == MAP_VALUE_INDEX =>
map.copy(valueType = addColumnInChild(map.valueType, column, position.tail))
case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
array.copy(elementType = addColumnInChild(array.elementType, column, position.tail))
case _: ArrayType =>
throw DeltaErrors.incorrectArrayAccess()
case other =>
throw DeltaErrors.addColumnParentNotStructException(column, other)
}
}

require(position.nonEmpty, s"Don't know where to add the column $column")
val slicePosition = position.head
if (slicePosition < 0) {
@@ -632,53 +694,16 @@
}
return StructType(schema :+ column)
}
val pre = schema.take(slicePosition)
val (pre, post) = schema.splitAt(slicePosition)
if (position.length > 1) {
val posTail = position.tail
val mid = schema(slicePosition) match {
case StructField(name, f: StructType, nullable, metadata) =>
if (!column.nullable && nullable) {
throw DeltaErrors.nullableParentWithNotNullNestedField
}
StructField(
name,
addColumn(f, column, posTail),
nullable,
metadata)
case StructField(name, ArrayType(f: StructType, containsNull), nullable, metadata) =>
if (!column.nullable && nullable) {
throw DeltaErrors.nullableParentWithNotNullNestedField
}

if (posTail.head != ARRAY_ELEMENT_INDEX) {
throw DeltaErrors.incorrectArrayAccess()
}

StructField(
name,
ArrayType(addColumn(f, column, posTail.tail), containsNull),
nullable,
metadata)
case StructField(name, map @ MapType(_, _, _), nullable, metadata) =>
if (!column.nullable && nullable) {
throw DeltaErrors.nullableParentWithNotNullNestedField
}

val addedMap = (posTail.head, map) match {
case (MAP_KEY_INDEX, MapType(key: StructType, v, nullability)) =>
MapType(addColumn(key, column, posTail.tail), v, nullability)
case (MAP_VALUE_INDEX, MapType(k, value: StructType, nullability)) =>
MapType(k, addColumn(value, column, posTail.tail), nullability)
case _ =>
throw DeltaErrors.addColumnParentNotStructException(column, IntegerType)
}
StructField(name, addedMap, nullable, metadata)
case o =>
throw DeltaErrors.addColumnParentNotStructException(column, o.dataType)
val field = post.head
if (!column.nullable && field.nullable) {
throw DeltaErrors.nullableParentWithNotNullNestedField
}
StructType(pre ++ Seq(mid) ++ schema.slice(slicePosition + 1, length))
val mid = field.copy(dataType = addColumnInChild(field.dataType, column, position.tail))
StructType(pre ++ Seq(mid) ++ post.tail)
} else {
StructType(pre ++ Seq(column) ++ schema.slice(slicePosition, length))
StructType(pre ++ Seq(column) ++ post)
}
}

@@ -693,6 +718,27 @@
* result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c3>>
*/
def dropColumn(schema: StructType, position: Seq[Int]): (StructType, StructField) = {
def dropColumnInChild(parent: DataType, position: Seq[Int]): (DataType, StructField) = {
require(position.nonEmpty, s"Don't know where to drop the column")
parent match {
case struct: StructType =>
dropColumn(struct, position)
case map: MapType if position.head == MAP_KEY_INDEX =>
val (newKeyType, droppedColumn) = dropColumnInChild(map.keyType, position.tail)
map.copy(keyType = newKeyType) -> droppedColumn
case map: MapType if position.head == MAP_VALUE_INDEX =>
val (newValueType, droppedColumn) = dropColumnInChild(map.valueType, position.tail)
map.copy(valueType = newValueType) -> droppedColumn
case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
val (newElementType, droppedColumn) = dropColumnInChild(array.elementType, position.tail)
array.copy(elementType = newElementType) -> droppedColumn
case _: ArrayType =>
throw DeltaErrors.incorrectArrayAccess()
case other =>
throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(other)
}
}

require(position.nonEmpty, "Don't know where to drop the column")
val slicePosition = position.head
if (slicePosition < 0) {
@@ -702,22 +748,19 @@
if (slicePosition >= length) {
throw DeltaErrors.indexLargerOrEqualThanStruct(slicePosition, length)
}
val pre = schema.take(slicePosition)
val (pre, post) = schema.splitAt(slicePosition)
val field = post.head
if (position.length > 1) {
val (mid, original) = schema(slicePosition) match {
case StructField(name, f: StructType, nullable, metadata) =>
val (dropped, original) = dropColumn(f, position.tail)
(StructField(name, dropped, nullable, metadata), original)
case o =>
throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(o)
}
(StructType(pre ++ Seq(mid) ++ schema.slice(slicePosition + 1, length)), original)
val (newType, droppedColumn) = dropColumnInChild(field.dataType, position.tail)
val mid = field.copy(dataType = newType)

StructType(pre ++ Seq(mid) ++ post.tail) -> droppedColumn
} else {
if (length == 1) {
throw new AnalysisException(
"Cannot drop column from a struct type with a single field: " + schema)
}
(StructType(pre ++ schema.slice(slicePosition + 1, length)), schema(slicePosition))
StructType(pre ++ post.tail) -> field
}
}
