From 62d11af65d86299143e75e2bca8e0f4214e77300 Mon Sep 17 00:00:00 2001 From: Erni Durdevic Date: Thu, 14 Apr 2022 16:46:51 +0200 Subject: [PATCH] Fixed mosaicfill for column argument --- python/mosaic/api/functions.py | 12 ++++++++---- python/test/test_functions.py | 1 + .../labs/mosaic/functions/MosaicContext.scala | 8 ++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/mosaic/api/functions.py b/python/mosaic/api/functions.py index 511e8ccce..9896d1375 100644 --- a/python/mosaic/api/functions.py +++ b/python/mosaic/api/functions.py @@ -1,8 +1,8 @@ import inspect -from typing import overload +from typing import overload, Any from pyspark.sql import Column -from pyspark.sql.functions import col, _to_java_column as pyspark_to_java_column +from pyspark.sql.functions import lit, _to_java_column as pyspark_to_java_column from mosaic.config import config from mosaic.utils.types import ColumnOrName, as_typed_col @@ -547,7 +547,7 @@ def mosaic_explode(geom: ColumnOrName, resolution: ColumnOrName, keep_core_geome keep_core_geometries ) -def mosaicfill(geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: bool = True) -> Column: +def mosaicfill(geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True) -> Column: """ Generates: - a set of core indices that are fully contained by `geom`; and @@ -568,9 +568,13 @@ def mosaicfill(geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometrie if keep_core_geometries is set to False. """ + + if(type(keep_core_geometries) == bool): + keep_core_geometries = lit(keep_core_geometries) + return config.mosaic_context.invoke_function( "mosaicfill", pyspark_to_java_column(geom), pyspark_to_java_column(resolution), - keep_core_geometries + pyspark_to_java_column(keep_core_geometries) ) diff --git a/python/test/test_functions.py b/python/test/test_functions.py index 1645f1193..a2c973005 100644 --- a/python/test/test_functions.py +++ b/python/test/test_functions.py @@ -65,6 +65,7 @@ def test_st_bindings_happy_flow(self): .withColumn("mosaicfill", api.mosaicfill("wkt", lit(1))) .withColumn("mosaic_explode_no_core_chips", api.mosaic_explode("wkt", lit(1), False)) .withColumn("mosaicfill_no_core_chips", api.mosaicfill("wkt", lit(1), False)) + .withColumn("mosaicfill_no_core_chips_bool", api.mosaicfill("wkt", lit(1), lit(False))) ) self.assertEqual(result.count(), 1) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 42b7d878f..66d62e49f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -404,9 +404,13 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def mosaicfill(geom: Column, resolution: Int): Column = ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, lit(true).expr, indexSystem.name, geometryAPI.name)) def mosaicfill(geom: Column, resolution: Column, keepCoreGeometries: Boolean): Column = - ColumnAdapter(MosaicFill(geom.expr, resolution.expr, lit(keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) + ColumnAdapter(MosaicFill(geom.expr, resolution.expr, lit(keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) def mosaicfill(geom: Column, resolution: Int, keepCoreGeometries: Boolean): Column = - ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, lit(keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) + ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, lit(keepCoreGeometries).expr, indexSystem.name, geometryAPI.name)) + def mosaicfill(geom: Column, resolution: Column, keepCoreGeometries: Column): Column = + ColumnAdapter(MosaicFill(geom.expr, resolution.expr, keepCoreGeometries.expr, indexSystem.name, geometryAPI.name)) + def mosaicfill(geom: Column, resolution: Int, keepCoreGeometries: Column): Column = + ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, keepCoreGeometries.expr, indexSystem.name, geometryAPI.name)) def point_index_geom(point: Column, resolution: Column): Column = ColumnAdapter(PointIndexGeom(point.expr, resolution.expr, indexSystem.name, geometryAPI.name)) def point_index_geom(point: Column, resolution: Int): Column =