From 6fce1afee174a7a834b12b286bc3392e41e35244 Mon Sep 17 00:00:00 2001 From: Jing Chen He Date: Sat, 15 Dec 2018 08:41:16 -0600 Subject: [PATCH] [SPARK-26315][PYSPARK] auto cast threshold from Integer to Float in approxSimilarityJoin of BucketedRandomProjectionLSHModel ## What changes were proposed in this pull request? If the input parameter 'threshold' to the function approxSimilarityJoin is not a float, we would get an exception. The fix is to convert the 'threshold' into a float before calling the java implementation method. ## How was this patch tested? Added a new test case. Without this fix, the test will throw an exception as reported in the JIRA. With the fix, the test passes. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23313 from jerryjch/SPARK-26315. Authored-by: Jing Chen He Signed-off-by: Sean Owen --- python/pyspark/ml/feature.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c9507c20918e3..08ae58246adb6 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -192,6 +192,7 @@ def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol") "datasetA" and "datasetB", and a column "distCol" is added to show the distance between each pair. """ + threshold = TypeConverters.toFloat(threshold) return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol) @@ -239,6 +240,16 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp | 3| 6| 2.23606797749979| +---+---+-----------------+ ... + >>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select( + ... col("datasetA.id").alias("idA"), + ... col("datasetB.id").alias("idB"), + ... col("EuclideanDistance")).show() + +---+---+-----------------+ + |idA|idB|EuclideanDistance| + +---+---+-----------------+ + | 3| 6| 2.23606797749979| + +---+---+-----------------+ + ... >>> brpPath = temp_path + "/brp" >>> brp.save(brpPath) >>> brp2 = BucketedRandomProjectionLSH.load(brpPath)