diff --git a/docs/source/api/spatial-functions.rst b/docs/source/api/spatial-functions.rst index c9d93c01e..c3f541330 100644 --- a/docs/source/api/spatial-functions.rst +++ b/docs/source/api/spatial-functions.rst @@ -1036,7 +1036,7 @@ polyfill mosaicfill ********** -.. function:: mosaicfill(geometry, resolution) +.. function:: mosaicfill(geometry, resolution, keep_core_geometries) Generates: - a set of core indices that are fully contained by `geometry`; and @@ -1048,6 +1048,8 @@ mosaicfill :type geometry: Column :param resolution: Index resolution :type resolution: Column: Integer + :param keep_core_geometries: Whether to keep the core geometries or set them to null + :type keep_core_geometries: Column: Boolean :rtype: Column: ArrayType[MosaicType] :example: @@ -1058,7 +1060,7 @@ mosaicfill >>> df = spark.createDataFrame([{'wkt': 'MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))'}]) >>> df.select(mosaicfill('wkt', lit(0))).printSchema() root - |-- h3_mosaicfill(wkt, 0): mosaic (nullable = true) + |-- mosaicfill(wkt, 0): mosaic (nullable = true) | |-- chips: array (nullable = true) | | |-- element: mosaic_chip (containsNull = true) | | | |-- is_core: boolean (nullable = true) @@ -1068,7 +1070,7 @@ mosaicfill >>> df.select(mosaicfill('wkt', lit(0))).show() +---------------------+ - |h3_mosaicfill(wkt, 0)| + |mosaicfill(wkt, 0) | +---------------------+ | {[{false, 5774810...| +---------------------+ @@ -1078,7 +1080,7 @@ mosaicfill >>> val df = List(("MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))")).toDF("wkt") >>> df.select(mosaicfill($"wkt", lit(0))).printSchema root - |-- h3_mosaicfill(wkt, 0): mosaic (nullable = true) + |-- mosaicfill(wkt, 0): mosaic (nullable = true) | |-- chips: array (nullable = true) | | |-- element: mosaic_chip (containsNull = true) | | | |-- is_core: boolean (nullable = true) @@ -1087,7 +1089,7 @@ mosaicfill >>> df.select(mosaicfill($"wkt", lit(0))).show() +---------------------+ - |h3_mosaicfill(wkt, 0)| + |mosaicfill(wkt, 0) | +---------------------+ | {[{false, 5774810...| +---------------------+ @@ -1096,7 +1098,7 @@ mosaicfill >>> SELECT mosaicfill("MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", 0) +---------------------+ - |h3_mosaicfill(wkt, 0)| + |mosaicfill(wkt, 0) | +---------------------+ | {[{false, 5774810...| +---------------------+ @@ -1105,7 +1107,7 @@ mosaicfill mosaic_explode ************** -.. function:: mosaic_explode(geometry, resolution) +.. function:: mosaic_explode(geometry, resolution, keep_core_geometries) Returns the set of Mosaic chips covering the input `geometry` at `resolution`. @@ -1115,6 +1117,8 @@ mosaic_explode :type geometry: Column :param resolution: Index resolution :type resolution: Column: Integer + :param keep_core_geometries: Whether to keep the core geometries or set them to null + :type keep_core_geometries: Column: Boolean :rtype: Column: MosaicType :example: diff --git a/python/mosaic/api/functions.py b/python/mosaic/api/functions.py index 7cd34bc29..662d2119d 100644 --- a/python/mosaic/api/functions.py +++ b/python/mosaic/api/functions.py @@ -1,10 +1,11 @@ from pyspark.sql import Column from pyspark.sql.functions import _to_java_column as pyspark_to_java_column -from pyspark.sql.functions import col - +from pyspark.sql.functions import lit +from typing import Any from mosaic.config import config from mosaic.utils.types import ColumnOrName, as_typed_col + ##################### # Spatial functions # ##################### @@ -592,7 +593,7 @@ def polyfill(geom: ColumnOrName, resolution: ColumnOrName) -> Column: ) -def mosaic_explode(geom: ColumnOrName, resolution: ColumnOrName) -> Column: +def mosaic_explode(geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True) -> Column: """ Generates: - a set of core indices that are fully contained by `geom`; and @@ -604,21 +605,27 @@ def mosaic_explode(geom: ColumnOrName, resolution: ColumnOrName) -> Column: ---------- geom : Column resolution : Column (IntegerType) + keep_core_geometries : Column (BooleanType) | bool Returns ------- Column (StructType[is_core: BooleanType, h3: LongType, wkb: BinaryType]) - `wkb` in this struct represents a border chip geometry and is null for all 'core' chips. + `wkb` in this struct represents a border chip geometry and is null for all 'core' chips + if keep_core_geometries is set to False. """ + if(type(keep_core_geometries) == bool): + keep_core_geometries = lit(keep_core_geometries) + return config.mosaic_context.invoke_function( "mosaic_explode", pyspark_to_java_column(geom), pyspark_to_java_column(resolution), + pyspark_to_java_column(keep_core_geometries) ) -def mosaicfill(geom: ColumnOrName, resolution: ColumnOrName) -> Column: +def mosaicfill(geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True) -> Column: """ Generates: - a set of core indices that are fully contained by `geom`; and @@ -630,15 +637,22 @@ def mosaicfill(geom: ColumnOrName, resolution: ColumnOrName) -> Column: ---------- geom : Column resolution : Column (IntegerType) + keep_core_geometries : Column (BooleanType) | bool Returns ------- Column (ArrayType[StructType[is_core: BooleanType, h3: LongType, wkb: BinaryType]]) - `wkb` in this struct represents a border chip geometry and is null for all 'core' chips. + `wkb` in this struct represents a border chip geometry and is null for all 'core' chips + if keep_core_geometries is set to False. """ + + if(type(keep_core_geometries) == bool): + keep_core_geometries = lit(keep_core_geometries) + return config.mosaic_context.invoke_function( "mosaicfill", pyspark_to_java_column(geom), pyspark_to_java_column(resolution), + pyspark_to_java_column(keep_core_geometries) ) diff --git a/python/test/test_functions.py b/python/test/test_functions.py index a9a1d7192..e5c6be3ed 100644 --- a/python/test/test_functions.py +++ b/python/test/test_functions.py @@ -64,7 +64,11 @@ def test_st_bindings_happy_flow(self): .withColumn("index_geometry", api.index_geometry(lit(1))) .withColumn("polyfill", api.polyfill("wkt", lit(1))) .withColumn("mosaic_explode", api.mosaic_explode("wkt", lit(1))) + .withColumn("mosaic_explode_no_core_chips", api.mosaic_explode("wkt", lit(1), lit(False))) + .withColumn("mosaic_explode_no_core_chips_bool", api.mosaic_explode("wkt", lit(1), False)) .withColumn("mosaicfill", api.mosaicfill("wkt", lit(1))) + .withColumn("mosaicfill_no_core_chips", api.mosaicfill("wkt", lit(1), False)) + .withColumn("mosaicfill_no_core_chips_bool", api.mosaicfill("wkt", lit(1), lit(False))) .withColumn( "geom_with_srid", api.st_setsrid(api.st_geomfromwkt("wkt"), lit(4326)) ) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/Mosaic.scala b/src/main/scala/com/databricks/labs/mosaic/core/Mosaic.scala index 63f1ee5ba..5de17cca4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/Mosaic.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/Mosaic.scala @@ -17,7 +17,7 @@ import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.{LINESTRING, */ object Mosaic { - def mosaicFill(geometry: MosaicGeometry, resolution: Int, indexSystem: IndexSystem, geometryAPI: GeometryAPI): Seq[MosaicChip] = { + def mosaicFill(geometry: MosaicGeometry, resolution: Int, keepCoreGeom: Boolean, indexSystem: IndexSystem, geometryAPI: GeometryAPI): Seq[MosaicChip] = { val radius = indexSystem.getBufferRadius(geometry, resolution, geometryAPI) @@ -34,8 +34,8 @@ object Mosaic { val coreIndices = indexSystem.polyfill(carvedGeometry, resolution) val borderIndices = indexSystem.polyfill(borderGeometry, resolution) - val coreChips = indexSystem.getCoreChips(coreIndices) - val borderChips = indexSystem.getBorderChips(geometry, borderIndices, geometryAPI) + val coreChips = indexSystem.getCoreChips(coreIndices, keepCoreGeom, geometryAPI) + val borderChips = indexSystem.getBorderChips(geometry, borderIndices, keepCoreGeom, geometryAPI) coreChips ++ borderChips } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala index 683b030df..0992cc8d0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala @@ -128,12 +128,17 @@ object H3IndexSystem extends IndexSystem with Serializable { override def getBorderChips( geometry: MosaicGeometry, borderIndices: util.List[java.lang.Long], + keepCoreGeom: Boolean, geometryAPI: GeometryAPI ): Seq[MosaicChip] = { val intersections = for (index <- borderIndices.asScala) yield { val indexGeom = indexToGeometry(index, geometryAPI) - val chip = MosaicChip(isCore = false, index, indexGeom) - chip.intersection(geometry) + val intersect = geometry.intersection(indexGeom) + val isCore = intersect.equals(indexGeom) + + val chipGeom = if (!isCore || keepCoreGeom) intersect else null + + MosaicChip(isCore = isCore, index, chipGeom) } intersections.filterNot(_.isEmpty) } @@ -146,8 +151,11 @@ object H3IndexSystem extends IndexSystem with Serializable { * @return * A core area representation via [[MosaicChip]] set. */ - override def getCoreChips(coreIndices: util.List[lang.Long]): Seq[MosaicChip] = { - coreIndices.asScala.map(MosaicChip(true, _, null)) + override def getCoreChips(coreIndices: util.List[lang.Long], keepCoreGeom: Boolean, geometryAPI: GeometryAPI): Seq[MosaicChip] = { + coreIndices.asScala.map(index => { + val indexGeom = if (keepCoreGeom) indexToGeometry(index, geometryAPI) else null + MosaicChip(isCore = true, index, indexGeom) + }) } /** diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala index e29a55e95..e5f4f7992 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala @@ -100,7 +100,7 @@ trait IndexSystem extends Serializable { * @return * A border area representation via [[MosaicChip]] set. */ - def getBorderChips(geometry: MosaicGeometry, borderIndices: util.List[java.lang.Long], geometryAPI: GeometryAPI): Seq[MosaicChip] + def getBorderChips(geometry: MosaicGeometry, borderIndices: util.List[java.lang.Long], keepCoreGeom: Boolean, geometryAPI: GeometryAPI): Seq[MosaicChip] /** * Return a set of [[MosaicChip]] instances computed based on the core @@ -113,7 +113,7 @@ trait IndexSystem extends Serializable { * @return * A core area representation via [[MosaicChip]] set. */ - def getCoreChips(coreIndices: util.List[java.lang.Long]): Seq[MosaicChip] + def getCoreChips(coreIndices: util.List[java.lang.Long], keepCoreGeom: Boolean, geometryAPI: GeometryAPI): Seq[MosaicChip] /** * Get the geometry corresponding to the index with the input id. diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala index 804d5b67a..d64d993af 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala @@ -17,26 +17,6 @@ import org.apache.spark.sql.catalyst.InternalRow */ case class MosaicChip(isCore: Boolean, index: Long, geom: MosaicGeometry) { - /** - * Perform an intersection with a geometry, and if intersection is non - * empty and the chip is not a core set chip then extract the chip - * geometry. - * - * @param other - * Geometry instance. - * @return - * A Mosaic Chip instance. - */ - def intersection(other: MosaicGeometry): MosaicChip = { - val intersect = other.intersection(geom) - val isCore = intersect.equals(geom) - if (isCore) { - MosaicChip(isCore, index, null) - } else { - MosaicChip(isCore, index, intersect) - } - } - /** * Indicates whether the chip is outside of the representation of the * geometry it was generated to represent (ie false positive index). diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplode.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplode.scala index 40ef15385..af3b3cdf9 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplode.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplode.scala @@ -68,11 +68,12 @@ object MosaicExplode { val geometry = geometryAPI.geometry(inputData, geomType) val resolution = inputData.getInt(1) + val keepCoreGeom = inputData.getBoolean(2) val chips = GeometryTypeEnum.fromString(geometry.getGeometryType) match { case LINESTRING => Mosaic.lineFill(geometry, resolution, indexSystem, geometryAPI) case MULTILINESTRING => Mosaic.lineFill(geometry, resolution, indexSystem, geometryAPI) - case _ => Mosaic.mosaicFill(geometry, resolution, indexSystem, geometryAPI) + case _ => Mosaic.mosaicFill(geometry, resolution, keepCoreGeom, indexSystem, geometryAPI) } chips.map(row => InternalRow.fromSeq(Seq(row.serialize))) @@ -92,14 +93,15 @@ object MosaicExplode { val fields = child.dataType.asInstanceOf[StructType].fields val geomType = fields.head val resolutionType = fields(1) + val keepCoreGeom = fields(2) - (geomType.dataType, resolutionType.dataType) match { - case (BinaryType, IntegerType) => TypeCheckResult.TypeCheckSuccess - case (StringType, IntegerType) => TypeCheckResult.TypeCheckSuccess - case (HexType, IntegerType) => TypeCheckResult.TypeCheckSuccess - case (InternalGeometryType, IntegerType) => TypeCheckResult.TypeCheckSuccess + (geomType.dataType, resolutionType.dataType, keepCoreGeom.dataType) match { + case (BinaryType, IntegerType, BooleanType) => TypeCheckResult.TypeCheckSuccess + case (StringType, IntegerType, BooleanType) => TypeCheckResult.TypeCheckSuccess + case (HexType, IntegerType, BooleanType) => TypeCheckResult.TypeCheckSuccess + case (InternalGeometryType, IntegerType, BooleanType) => TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( - s"Input to h3 mosaic explode should be (geometry, resolution) pair. " + + s"Input to h3 mosaic explode should be (geometry, resolution, keepCoreGeom) pair. " + s"Geometry type can be WKB, WKT, Hex or Coords. Provided type was: ${child.dataType.catalogString}" ) } @@ -122,14 +124,15 @@ object MosaicExplode { val fields = child.dataType.asInstanceOf[StructType].fields val geomType = fields.head val resolutionType = fields(1) + val keepCoreGeom = fields(2) - (geomType.dataType, resolutionType.dataType) match { - case (BinaryType, IntegerType) => StructType(Array(StructField("index", ChipType))) - case (StringType, IntegerType) => StructType(Array(StructField("index", ChipType))) - case (HexType, IntegerType) => StructType(Array(StructField("index", ChipType))) - case (InternalGeometryType, IntegerType) => StructType(Array(StructField("index", ChipType))) + (geomType.dataType, resolutionType.dataType, keepCoreGeom.dataType) match { + case (BinaryType, IntegerType, BooleanType) => StructType(Array(StructField("index", ChipType))) + case (StringType, IntegerType, BooleanType) => StructType(Array(StructField("index", ChipType))) + case (HexType, IntegerType, BooleanType) => StructType(Array(StructField("index", ChipType))) + case (InternalGeometryType, IntegerType, BooleanType) => StructType(Array(StructField("index", ChipType))) case _ => throw new Error( - s"Input to h3 mosaic explode should be (geometry, resolution) pair. " + + s"Input to h3 mosaic explode should be (geometry, resolution, keepCoreGeom) pair. " + s"Geometry type can be WKB, WKT, Hex or Coords. Provided type was: ${child.dataType.catalogString}" ) } @@ -149,14 +152,14 @@ object MosaicExplode { db.orNull, "mosaic_explode", """ - | _FUNC_(struct(geometry, resolution)) - Generates the h3 mosaic chips for the input geometry - | at a given resolution. Geometry and resolution are provided via struct wrapper to ensure + | _FUNC_(struct(geometry, resolution, keepCoreGeom)) - Generates the h3 mosaic chips for the input + | geometry at a given resolution. Geometry and resolution are provided via struct wrapper to ensure | UnaryExpression API is respected. """.stripMargin, "", """ | Examples: - | > SELECT _FUNC_(a, b); + | > SELECT _FUNC_(a, b, c); | {index_id, is_border, chip_geom} | {index_id, is_border, chip_geom} | ... diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicFill.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicFill.scala index 93ed8fc9b..25070658e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicFill.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/index/MosaicFill.scala @@ -10,7 +10,7 @@ import org.locationtech.jts.geom.Geometry import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{ - BinaryExpression, + TernaryExpression, ExpectsInputTypes, Expression, ExpressionDescription, @@ -30,33 +30,36 @@ import org.apache.spark.sql.types._ """, since = "1.0" ) -case class MosaicFill(geom: Expression, resolution: Expression, indexSystemName: String, geometryAPIName: String) - extends BinaryExpression +case class MosaicFill(geom: Expression, resolution: Expression, keepCoreGeom: Expression, indexSystemName: String, geometryAPIName: String) + extends TernaryExpression with ExpectsInputTypes with NullIntolerant with CodegenFallback { // noinspection DuplicatedCode override def inputTypes: Seq[DataType] = - (left.dataType, right.dataType) match { - case (BinaryType, IntegerType) => Seq(BinaryType, IntegerType) - case (StringType, IntegerType) => Seq(StringType, IntegerType) - case (HexType, IntegerType) => Seq(HexType, IntegerType) - case (InternalGeometryType, IntegerType) => Seq(InternalGeometryType, IntegerType) - case _ => throw new IllegalArgumentException(s"Not supported data type: (${left.dataType}, ${right.dataType}).") + (first.dataType, second.dataType, third.dataType) match { + case (BinaryType, IntegerType, BooleanType) => Seq(BinaryType, IntegerType, BooleanType) + case (StringType, IntegerType, BooleanType) => Seq(StringType, IntegerType, BooleanType) + case (HexType, IntegerType, BooleanType) => Seq(HexType, IntegerType, BooleanType) + case (InternalGeometryType, IntegerType, BooleanType) => Seq(InternalGeometryType, IntegerType, BooleanType) + case _ => + throw new IllegalArgumentException(s"Not supported data type: (${first.dataType}, ${second.dataType}, ${third.dataType}).") } - override def right: Expression = resolution + override def first: Expression = geom - override def left: Expression = geom + override def second: Expression = resolution + + override def third: Expression = keepCoreGeom /** Expression output DataType. */ override def dataType: DataType = MosaicType - override def toString: String = s"h3_mosaicfill($geom, $resolution)" + override def toString: String = s"mosaicfill($geom, $resolution, $keepCoreGeom)" /** Overridden to ensure [[Expression.sql]] is properly formatted. */ - override def prettyName: String = "h3_mosaicfill" + override def prettyName: String = "mosaicfill" /** * Type-wise differences in evaluation are only present on the input data @@ -69,22 +72,25 @@ case class MosaicFill(geom: Expression, resolution: Expression, indexSystemName: * Any instance containing the geometry. * @param input2 * Any instance containing the resolution + * @param input3 + * Any instance defining if core chips should be geometries or nulls * @return * A set of serialized * [[com.databricks.labs.mosaic.core.types.model.MosaicChip]]. */ // noinspection DuplicatedCode - override def nullSafeEval(input1: Any, input2: Any): Any = { + override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { val resolution: Int = H3IndexSystem.getResolution(input2) + val keepCoreGeom: Boolean = input3.asInstanceOf[Boolean] val indexSystem = IndexSystemID.getIndexSystem(IndexSystemID(indexSystemName)) val geometryAPI = GeometryAPI(geometryAPIName) - val geometry = geometryAPI.geometry(input1, left.dataType) + val geometry = geometryAPI.geometry(input1, first.dataType) val chips = GeometryTypeEnum.fromString(geometry.getGeometryType) match { case LINESTRING => Mosaic.lineFill(geometry, resolution, indexSystem, geometryAPI) case MULTILINESTRING => Mosaic.lineFill(geometry, resolution, indexSystem, geometryAPI) - case _ => Mosaic.mosaicFill(geometry, resolution, indexSystem, geometryAPI) + case _ => Mosaic.mosaicFill(geometry, resolution, keepCoreGeom, indexSystem, geometryAPI) } val serialized = InternalRow.fromSeq( @@ -97,14 +103,14 @@ case class MosaicFill(geom: Expression, resolution: Expression, indexSystemName: } override def makeCopy(newArgs: Array[AnyRef]): Expression = { - val asArray = newArgs.take(2).map(_.asInstanceOf[Expression]) - val res = MosaicFill(asArray(0), asArray(1), indexSystemName, geometryAPIName) + val asArray = newArgs.take(3).map(_.asInstanceOf[Expression]) + val res = MosaicFill(asArray(0), asArray(1), asArray(2), indexSystemName, geometryAPIName) res.copyTagsFrom(this) res } - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = - copy(geom = newLeft, resolution = newRight) + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(geom = newFirst, resolution = newSecond, keepCoreGeom = newThird) } @@ -115,14 +121,14 @@ object MosaicFill { new ExpressionInfo( classOf[IndexGeometry].getCanonicalName, db.orNull, - "mosaic_fill", + "mosaicfill", """ - | _FUNC_(geometry, resolution) - Returns the 2 set representation of geometry at resolution. + | _FUNC_(geometry, resolution, keepCoreGeom) - Returns the 2 set representation of geometry at resolution. """.stripMargin, "", """ | Examples: - | > SELECT _FUNC_(a, b); + | > SELECT _FUNC_(a, b, c); | [{index_id, is_border, chip_geom}, {index_id, is_border, chip_geom}, ..., {index_id, is_border, chip_geom}] | """.stripMargin, "", diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index ec3f77241..617dbc4a5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -284,12 +284,25 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends FunctionIdentifier("mosaic_explode", database), MosaicExplode.registryExpressionInfo(database), (exprs: Seq[Expression]) => - MosaicExplode(struct(ColumnAdapter(exprs(0)), ColumnAdapter(exprs(1))).expr, indexSystem.name, geometryAPI.name) + exprs match { + case e if e.length == 2 => + MosaicExplode(struct(ColumnAdapter(e(0)), ColumnAdapter(e(1)), lit(true)).expr, indexSystem.name, geometryAPI.name) + case e if e.length == 3 => + MosaicExplode( + struct(ColumnAdapter(e(0)), ColumnAdapter(e(1)), ColumnAdapter(e(2))).expr, + indexSystem.name, + geometryAPI.name + ) + } ) registry.registerFunction( FunctionIdentifier("mosaicfill", database), MosaicFill.registryExpressionInfo(database), - (exprs: Seq[Expression]) => MosaicFill(exprs(0), exprs(1), indexSystem.name, geometryAPI.name) + (exprs: Seq[Expression]) => + exprs match { + case e if e.length == 2 => MosaicFill(e(0), e(1), lit(true).expr, indexSystem.name, geometryAPI.name) + case e if e.length == 3 => MosaicFill(e(0), e(1), e(2), indexSystem.name, geometryAPI.name) + } ) registry.registerFunction( FunctionIdentifier("point_index_lonlat", database), @@ -397,13 +410,29 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends /** IndexSystem and GeometryAPI Specific methods */ def mosaic_explode(geom: Column, resolution: Column): Column = - ColumnAdapter(MosaicExplode(struct(geom, resolution).expr, indexSystem.name, geometryAPI.name)) + ColumnAdapter(MosaicExplode(struct(geom, resolution, lit(true)).expr, indexSystem.name, geometryAPI.name)) + def mosaic_explode(geom: Column, resolution: Column, keepCoreGeometries: Boolean): Column = + ColumnAdapter(MosaicExplode(struct(geom, resolution, lit(keepCoreGeometries)).expr, indexSystem.name, geometryAPI.name)) + def mosaic_explode(geom: Column, resolution: Column, keepCoreGeometries: Column): Column = + ColumnAdapter(MosaicExplode(struct(geom, resolution, keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) def mosaic_explode(geom: Column, resolution: Int): Column = - ColumnAdapter(MosaicExplode(struct(geom, lit(resolution)).expr, indexSystem.name, geometryAPI.name)) + ColumnAdapter(MosaicExplode(struct(geom, lit(resolution), lit(true)).expr, indexSystem.name, geometryAPI.name)) + def mosaic_explode(geom: Column, resolution: Int, keepCoreGeometries: Boolean): Column = + ColumnAdapter(MosaicExplode(struct(geom, lit(resolution), lit(keepCoreGeometries)).expr, indexSystem.name, geometryAPI.name)) + def mosaic_explode(geom: Column, resolution: Int, keepCoreGeometries: Column): Column = + ColumnAdapter(MosaicExplode(struct(geom, lit(resolution), keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) def mosaicfill(geom: Column, resolution: Column): Column = - ColumnAdapter(MosaicFill(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name)) + ColumnAdapter(MosaicFill(geom.expr, resolution.expr, lit(true).expr, indexSystem.name, geometryAPI.name)) def mosaicfill(geom: Column, resolution: Int): Column = - ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name)) + ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, lit(true).expr, indexSystem.name, geometryAPI.name)) + def mosaicfill(geom: Column, resolution: Column, keepCoreGeometries: Boolean): Column = + ColumnAdapter(MosaicFill(geom.expr, resolution.expr, lit(keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) + def mosaicfill(geom: Column, resolution: Int, keepCoreGeometries: Boolean): Column = + ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, lit(keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) + def mosaicfill(geom: Column, resolution: Column, keepCoreGeometries: Column): Column = + ColumnAdapter(MosaicFill(geom.expr, resolution.expr, keepCoreGeometries.expr, indexSystem.name, geometryAPI.name)) + def mosaicfill(geom: Column, resolution: Int, keepCoreGeometries: Column): Column = + ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, keepCoreGeometries.expr, indexSystem.name, geometryAPI.name)) def point_index_geom(point: Column, resolution: Column): Column = ColumnAdapter(PointIndexGeom(point.expr, resolution.expr, indexSystem.name, geometryAPI.name)) def point_index_geom(point: Column, resolution: Int): Column = diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala index ec4c3d198..b1c3c638d 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala @@ -4,9 +4,10 @@ import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.mocks.{getBoroughs, getWKTRowsDf} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers._ - -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructField, StructType} + trait MosaicExplodeBehaviors { this: AnyFlatSpec => @@ -37,6 +38,67 @@ trait MosaicExplodeBehaviors { boroughs.collect().length should be < mosaics2.length } + def wktDecomposeNoNulls(mosaicContext: => MosaicContext, spark: => SparkSession): Unit = { + val mc = mosaicContext + import mc.functions._ + mosaicContext.register(spark) + + val rdd = spark.sparkContext.makeRDD(Seq( + Row("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))") + )) + val schema = StructType( + List( + StructField("wkt", StringType) + ) + ) + val df = spark.createDataFrame(rdd, schema) + + val noEmptyChips = df + .select( + mosaic_explode(col("wkt"), 4, true) + ) + .filter(col("index.wkb").isNull) + .count() + + noEmptyChips should equal(0) + + val emptyChips = df + .select( + mosaic_explode(col("wkt"), 4, false) + ) + .filter(col("index.wkb").isNull) + + emptyChips.collect().length should be > 0 + } + + def wktDecomposeKeepCoreParamExpression(mosaicContext: => MosaicContext, spark: => SparkSession): Unit = { + val mc = mosaicContext + import mc.functions._ + mosaicContext.register(spark) + + val rdd = spark.sparkContext.makeRDD(Seq( + Row("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))") + )) + val schema = StructType( + List( + StructField("wkt", StringType) + ) + ) + val df = spark.createDataFrame(rdd, schema) + + val noEmptyChips = df + .select( + expr("mosaic_explode(wkt, 4, true)") + ) + noEmptyChips.collect().length should be > 0 + + val noEmptyChips_2 = df + .select( + expr("mosaic_explode(wkt, 4, false)") + ) + noEmptyChips_2.collect().length should be > 0 + } + def lineDecompose(mosaicContext: => MosaicContext, spark: => SparkSession): Unit = { val mc = mosaicContext import mc.functions._ @@ -61,6 +123,7 @@ trait MosaicExplodeBehaviors { .collect() wktRows.collect().length should be < mosaics2.length + } def wkbDecompose(mosaicContext: => MosaicContext, spark: => SparkSession): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala index 18351bf03..002b70466 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala @@ -115,4 +115,31 @@ trait MosaicFillBehaviors { boroughs.collect().length shouldEqual mosaics2.length } + + def wktMosaicFillKeepCoreGeom(mosaicContext: => MosaicContext, spark: => SparkSession): Unit = { + val mc = mosaicContext + import mc.functions._ + mosaicContext.register(spark) + + val boroughs: DataFrame = getBoroughs + + val mosaics = boroughs + .select( + mosaicfill(col("wkt"), 11, true) + ) + .collect() + + boroughs.collect().length shouldEqual mosaics.length + + boroughs.createOrReplaceTempView("boroughs") + + val mosaics2 = spark + .sql(""" + |select mosaicfill(wkt, 11, true) from boroughs + |""".stripMargin) + .collect() + + boroughs.collect().length shouldEqual mosaics2.length + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicExplode.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicExplode.scala index 47aac5a04..df922755f 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicExplode.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicExplode.scala @@ -13,6 +13,16 @@ class TestMosaicExplode extends AnyFlatSpec with MosaicExplodeBehaviors with Spa it should behave like wktDecompose(MosaicContext.build(H3IndexSystem, JTS), spark) } + "Mosaic_Explode" should "decompose wkt geometries for any index system and any geometry API with SQL expr" in { + it should behave like wktDecomposeKeepCoreParamExpression(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like wktDecomposeKeepCoreParamExpression(MosaicContext.build(H3IndexSystem, JTS), spark) + } + + "Mosaic_Explode" should "decompose wkt geometries with no null for any index system and any geometry API" in { + it should behave like wktDecomposeNoNulls(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like wktDecomposeNoNulls(MosaicContext.build(H3IndexSystem, JTS), spark) + } + "Mosaic_Explode" should "decompose wkb geometries for any index system and any geometry API" in { it should behave like wkbDecompose(MosaicContext.build(H3IndexSystem, ESRI), spark) it should behave like wkbDecompose(MosaicContext.build(H3IndexSystem, JTS), spark) diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicFill.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicFill.scala index 41add5e90..827492bae 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicFill.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/TestMosaicFill.scala @@ -28,4 +28,9 @@ class TestMosaicFill extends AnyFlatSpec with MosaicFillBehaviors with SparkSuit it should behave like coordsMosaicFill(MosaicContext.build(H3IndexSystem, JTS), spark) } + "MosaicFill" should "fill wkt geometries with keepCoreGeom parameter" in { + it should behave like wktMosaicFillKeepCoreGeom(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like wktMosaicFillKeepCoreGeom(MosaicContext.build(H3IndexSystem, JTS), spark) + } + }