Skip to content

Commit

Permalink
Merge pull request #105 from databrickslabs/feature/st_buffer
Browse files Browse the repository at this point in the history
Feature/st buffer
  • Loading branch information
milos-colic authored Apr 28, 2022
2 parents 5b53b7b + 8d23f32 commit 5665387
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 3 deletions.
48 changes: 48 additions & 0 deletions docs/source/api/spatial-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,54 @@ st_area

.. note:: Results of this function are always expressed in the original units of the input geometry.




st_buffer
*********

.. function:: st_buffer(col, radius)

Buffer the input geometry by radius `radius` and return a new, buffered geometry.

:param col: Geometry
:type col: Column
:param radius: Double
:type radius: Column (DoubleType)
:rtype: Column: Geometry

:example:

.. tabs::
.. code-tab:: py

>>> df = spark.createDataFrame([{'wkt': 'POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))'}])
>>> df.select(st_buffer('wkt', lit(2.))).show()
+--------------------+
| st_buffer(wkt, 2.0)|
+--------------------+
|POLYGON ((29.1055...|
+--------------------+

.. code-tab:: scala

>>> val df = List(("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")).toDF("wkt")
>>> df.select(st_buffer($"wkt", 2d)).show()
+--------------------+
| st_buffer(wkt, 2.0)|
+--------------------+
|POLYGON ((29.1055...|
+--------------------+

.. code-tab:: sql

>>> SELECT st_buffer("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))", 2d)
+--------------------+
| st_buffer(wkt, 2.0)|
+--------------------+
|POLYGON ((29.1055...|
+--------------------+

st_perimeter
************

Expand Down
2 changes: 1 addition & 1 deletion python/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Databricks
![mosaic-logo](src/main/resources/mosaic_logo.png)
![mosaic-logo](../src/main/resources/mosaic_logo.png)

An extension to the [Apache Spark](https://spark.apache.org/) framework that allows easy and fast processing of very large geospatial datasets.

Expand Down
21 changes: 21 additions & 0 deletions python/mosaic/api/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,27 @@ def st_convexhull(geom: ColumnOrName) -> Column:
"st_convexhull", pyspark_to_java_column(geom)
)

def st_buffer(geom: ColumnOrName, radius: ColumnOrName) -> Column:
    """
    Buffer the input geometry by the supplied radius.

    Both arguments are converted with `pyspark_to_java_column` and handed to
    the `st_buffer` function through the configured Mosaic context.

    Parameters
    ----------
    geom : Column
        The input geometry
    radius : Column
        The radius of buffering

    Returns
    -------
    Column
        A geometry

    """
    # Resolve the invoker first so the lookup order matches a direct call.
    invoke = config.mosaic_context.invoke_function
    geom_col = pyspark_to_java_column(geom)
    radius_col = pyspark_to_java_column(radius)
    return invoke("st_buffer", geom_col, radius_col)


def st_dump(geom: ColumnOrName) -> Column:
"""
Expand Down
1 change: 1 addition & 0 deletions python/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_st_bindings_happy_flow(self):
result = (
df.withColumn("st_area", api.st_area("wkt"))
.withColumn("st_length", api.st_length("wkt"))
.withColumn("st_buffer", api.st_buffer("wkt", lit(1.1)))
.withColumn("st_perimeter", api.st_perimeter("wkt"))
.withColumn("st_convexhull", api.st_convexhull("wkt"))
.withColumn("st_dump", api.st_dump("wkt"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.codegen.format.ConvertToCodeGen
import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.esri.core.geometry.ogc.OGCGeometry
import org.locationtech.jts.geom.{Geometry => JTSGeometry}

import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionInfo, NullIntolerant}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType

/**
  * Catalyst expression that buffers a geometry by a radius and returns the
  * buffered geometry serialized in the same data type as the input geometry.
  *
  * @param inputGeom
  *   Expression producing the geometry to buffer.
  * @param radius
  *   Expression producing the buffer radius (any numeric value is accepted in
  *   interpreted evaluation and widened to Double).
  * @param geometryAPIName
  *   Name of the geometry API used for evaluation; code generation supports
  *   "ESRI" and "JTS".
  */
case class ST_Buffer(inputGeom: Expression, radius: Expression, geometryAPIName: String) extends BinaryExpression with NullIntolerant {

    override def left: Expression = inputGeom

    override def right: Expression = radius

    /** The output keeps the data type of the input geometry. */
    override def dataType: DataType = inputGeom.dataType

    /**
      * Interpreted evaluation: deserialize the geometry, buffer it and
      * serialize the result back to the input data type.
      */
    override def nullSafeEval(geomRow: Any, radiusRow: Any): Any = {
        val geometryAPI = GeometryAPI(geometryAPIName)
        val geometry = geometryAPI.geometry(geomRow, inputGeom.dataType)
        // The radius is not guaranteed to be a Double (e.g. `st_buffer(wkt, 1)`
        // yields a boxed Integer); a blind cast would throw ClassCastException.
        val radiusVal: Double = radiusRow match {
            case d: Double           => d
            case n: java.lang.Number => n.doubleValue()
            case _                   =>
                throw new IllegalArgumentException(
                  s"st_buffer expects a numeric radius but received a value of type ${radius.dataType}."
                )
        }
        val buffered = geometry.buffer(radiusVal)
        geometryAPI.serialize(buffered, inputGeom.dataType)
    }

    /** Rebuilds the expression from the first two arguments, preserving the geometry API name and tags. */
    override def makeCopy(newArgs: Array[AnyRef]): Expression = {
        val asArray = newArgs.take(2).map(_.asInstanceOf[Expression])
        val res = ST_Buffer(asArray.head, asArray(1), geometryAPIName)
        res.copyTagsFrom(this)
        res
    }

    override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
        copy(inputGeom = newLeft, radius = newRight)

    /** Generates the Java code that reads the geometry, buffers it and writes it back. */
    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
        nullSafeCodeGen(
          ctx,
          ev,
          (leftEval, rightEval) => {
              val geometryAPI = GeometryAPI.apply(geometryAPIName)
              val buffered = ctx.freshName("buffered")
              val ogcGeometryClass = classOf[OGCGeometry].getName
              val jtsGeometryClass = classOf[JTSGeometry].getName
              val (inCode, geomInRef) = ConvertToCodeGen.readGeometryCode(ctx, leftEval, inputGeom.dataType, geometryAPI)
              val (outCode, outGeomRef) = ConvertToCodeGen.writeGeometryCode(ctx, buffered, inputGeom.dataType, geometryAPI)
              // not merged into the same code block due to JTS IOException throwing
              // ESRI code will always remain simpler
              geometryAPIName match {
                  case "ESRI" => s"""
                                    |$inCode
                                    |$ogcGeometryClass $buffered = $geomInRef.buffer($rightEval);
                                    |$outCode
                                    |${ev.value} = $outGeomRef;
                                    |""".stripMargin
                  case "JTS"  => s"""
                                    |try {
                                    |$inCode
                                    |$jtsGeometryClass $buffered = $geomInRef.buffer($rightEval);
                                    |$outCode
                                    |${ev.value} = $outGeomRef;
                                    |} catch (Exception e) {
                                    | throw e;
                                    |}
                                    |""".stripMargin
                  // Fail with a descriptive message rather than a bare MatchError.
                  case other  =>
                      throw new IllegalArgumentException(s"st_buffer code generation does not support geometry API '$other'.")
              }
          }
        )

}

/** Companion of the ST_Buffer expression; provides its function-registry metadata. */
object ST_Buffer {

/**
  * Entry to use in the function registry.
  *
  * Builds the ExpressionInfo that describes the `st_buffer` SQL function,
  * including its usage and example strings.
  *
  * @param db
  *   Optional database name; `null` is passed to ExpressionInfo when it is empty.
  * @return
  *   The ExpressionInfo for ST_Buffer registered under the name "st_buffer".
  */
def registryExpressionInfo(db: Option[String]): ExpressionInfo =
new ExpressionInfo(
classOf[ST_Buffer].getCanonicalName,
// ExpressionInfo is a Java class, hence null rather than Option for a missing database.
db.orNull,
"st_buffer",
"""
| _FUNC_(expr1, expr2) - Returns expr1 buffered by expr2.
""".stripMargin,
"",
"""
| Examples:
| > SELECT _FUNC_(a, b);
| POLYGON((1 1, 2 2, 3 3 ....))
| """.stripMargin,
"",
"misc_funcs",
"1.0",
"",
"built-in"
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
ST_ConvexHull.registryExpressionInfo(database),
(exprs: Seq[Expression]) => ST_ConvexHull(exprs(0), geometryAPI.name)
)
registry.registerFunction(
FunctionIdentifier("st_buffer", database),
ST_Buffer.registryExpressionInfo(database),
(exprs: Seq[Expression]) => ST_Buffer(exprs(0), exprs(1), geometryAPI.name)
)
registry.registerFunction(
FunctionIdentifier("st_numpoints", database),
ST_NumPoints.registryExpressionInfo(database),
Expand Down Expand Up @@ -388,6 +393,8 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
ColumnAdapter(ST_Scale(geom1.expr, xd.expr, yd.expr, geometryAPI.name))
def st_rotate(geom1: Column, td: Column): Column = ColumnAdapter(ST_Rotate(geom1.expr, td.expr, geometryAPI.name))
def st_convexhull(geom: Column): Column = ColumnAdapter(ST_ConvexHull(geom.expr, geometryAPI.name))
/** Buffers `geom` by the radius taken from the `radius` column. */
def st_buffer(geom: Column, radius: Column): Column = {
    val bufferExpr = ST_Buffer(geom.expr, radius.expr, geometryAPI.name)
    ColumnAdapter(bufferExpr)
}
/** Buffers `geom` by a constant `radius`, wrapped as a literal column. */
def st_buffer(geom: Column, radius: Double): Column = st_buffer(geom, lit(radius))
def st_numpoints(geom: Column): Column = ColumnAdapter(ST_NumPoints(geom.expr, geometryAPI.name))
def st_intersects(left: Column, right: Column): Column = ColumnAdapter(ST_Intersects(left.expr, right.expr, geometryAPI.name))
def st_intersection(left: Column, right: Column): Column = ColumnAdapter(ST_Intersection(left.expr, right.expr, geometryAPI.name))
Expand Down Expand Up @@ -467,12 +474,12 @@ object MosaicContext {

def geometryAPI: GeometryAPI = context.getGeometryAPI

def indexSystem: IndexSystem = context.getIndexSystem

/**
  * Returns the MosaicContext held in `instance`.
  *
  * Throws an IllegalStateException when no context has been built yet.
  */
def context: MosaicContext =
    instance.getOrElse(throw new IllegalStateException("MosaicContext was not built."))

def indexSystem: IndexSystem = context.getIndexSystem

}
Original file line number Diff line number Diff line change
Expand Up @@ -412,4 +412,60 @@ trait GeometryProcessorsBehaviors { this: AnyFlatSpec =>
result.collect().length > 0 shouldBe true
}

/**
  * Checks that `st_buffer` matches buffering done directly through the
  * geometry API, via both the Column API and SQL.
  *
  * The length of every geometry buffered by 1 is compared against the
  * reference length with an absolute tolerance of 1e-8.
  */
def bufferCalculation(mosaicContext: => MosaicContext, spark: => SparkSession): Unit = {
    val mc = mosaicContext
    val sc = spark
    import mc.functions._
    import sc.implicits._
    mosaicContext.register(spark)

    // Reference geometries parsed straight from the mock WKT rows, ordered by id.
    val wktById = mocks.wkt_rows.sortBy(row => row.head.asInstanceOf[Int])
    val wktStrings = wktById.map(row => row(1).asInstanceOf[String])
    val referenceGeoms: immutable.Seq[MosaicGeometry] = wktStrings.map(wkt => mc.getGeometryAPI.geometry(wkt, "WKT"))
    val expected = referenceGeoms.map(geom => geom.buffer(1).getLength)

    // Pairwise comparison of computed lengths against the reference lengths.
    def assertMatchesExpected(actual: Array[Double]): Unit =
        actual.zip(expected).foreach { case (computed, reference) => math.abs(computed - reference) should be < 1e-8 }

    val bufferedLength = st_length(st_buffer($"wkt", lit(1)))
    val result = mocks.getWKTRowsDf.orderBy("id").select(bufferedLength).as[Double].collect()
    assertMatchesExpected(result)

    mocks.getWKTRowsDf.createOrReplaceTempView("source")

    val sqlLengths = spark.sql("select id, st_length(st_buffer(wkt, 1)) from source")
    val sqlResult = sqlLengths.orderBy("id").drop("id").as[Double].collect()
    assertMatchesExpected(sqlResult)
}

/**
  * Checks that a query using `st_buffer` is planned with whole-stage code
  * generation and that the generated code compiles without throwing.
  */
def bufferCodegen(mosaicContext: => MosaicContext, spark: => SparkSession): Unit = {
    val mc = mosaicContext
    val sc = spark
    import mc.functions._
    import sc.implicits._
    mosaicContext.register(spark)

    val bufferedLengths = mocks.getWKTRowsDf.select(st_length(st_buffer($"wkt", lit(1))))
    val executedPlan = bufferedLengths.queryExecution.executedPlan

    // The executed plan must contain a whole-stage codegen node.
    val codegenNode = executedPlan.find(node => node.isInstanceOf[WholeStageCodegenExec])
    codegenNode.isDefined shouldBe true

    val (_, generatedCode) = codegenNode.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()

    noException should be thrownBy CodeGenerator.compile(generatedCode)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,9 @@ class TestGeometryProcessors extends AnyFlatSpec with GeometryProcessorsBehavior
it should behave like convexHullGeneration(MosaicContext.build(H3IndexSystem, JTS), spark)
}

// Runs the shared buffer-calculation behaviour once per geometry API (ESRI and JTS),
// both with the H3 index system.
"ST_Buffer" should "compute the buffer geometry for any geometry API" in {
it should behave like bufferCalculation(MosaicContext.build(H3IndexSystem, ESRI), spark)
it should behave like bufferCalculation(MosaicContext.build(H3IndexSystem, JTS), spark)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,11 @@ class TestGeometryProcessorsCodegen extends AnyFlatSpec with GeometryProcessorsB
it should behave like transformationsCodegen(MosaicContext.build(H3IndexSystem, JTS), spark)
}

// Subject name corrected from "ST_ buffer" to "ST_Buffer" so test reports name the
// expression consistently. Runs the calculation and codegen behaviours for both
// geometry APIs (ESRI and JTS) with the H3 index system.
"ST_Buffer" should "execute without errors for any index system and any geometry API" in {
    it should behave like bufferCalculation(MosaicContext.build(H3IndexSystem, ESRI), spark)
    it should behave like bufferCalculation(MosaicContext.build(H3IndexSystem, JTS), spark)
    it should behave like bufferCodegen(MosaicContext.build(H3IndexSystem, ESRI), spark)
    it should behave like bufferCodegen(MosaicContext.build(H3IndexSystem, JTS), spark)
}

}

0 comments on commit 5665387

Please sign in to comment.