diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..ed6cb49d
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,6 @@
+[flake8]
+exclude =
+    .git,
+    __pycache__,
+    .pytest_cache
+max-line-length = 120
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 828a5afd..a71424ec 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -43,7 +43,7 @@ jobs:
             3.9
       - name: Build and test
         run: |
-          sbt -java-home "$JAVA_HOME_8_X64" clean scalafmtCheckAll +test -DsparkVersion="$SPARK_VERSION"
+          sbt -java-home "$JAVA_HOME_8_X64" clean scalafmtCheckAll blackCheck flake8 +test -DsparkVersion="$SPARK_VERSION"
       - name: Publish Unit test results
         uses: mikepenz/action-junit-report@v4
         with:
diff --git a/build.sbt b/build.sbt
index 84c5b0d0..17a9da18 100644
--- a/build.sbt
+++ b/build.sbt
@@ -43,12 +43,24 @@ lazy val noPublishSettings =
 val hnswLibVersion = "1.1.2"
 val sparkVersion = settingKey[String]("Spark version")
 
-lazy val pyTest = taskKey[Unit]("Run the python tests")
+lazy val createVirtualEnv = taskKey[Unit]("Create venv")
+lazy val pyTest = taskKey[Unit]("Run the python tests")
+lazy val black = taskKey[Unit]("Run the black code formatter")
+lazy val blackCheck = taskKey[Unit]("Run the black code formatter in check mode")
+lazy val flake8 = taskKey[Unit]("Run the flake8 style enforcer")
 
 lazy val root = (project in file("."))
   .aggregate(hnswlibSpark)
   .settings(noPublishSettings)
 
+lazy val pythonVersion = Def.setting {
+  if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9"
+}
+
+lazy val venvFolder = Def.setting {
+  s"${baseDirectory.value}/.venv"
+}
+
 lazy val hnswlibSpark = (project in file("hnswlib-spark"))
   .settings(
     name := s"hnswlib-spark_${sparkVersion.value.split('.').take(2).mkString("_")}",
@@ -78,28 +90,48 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark"))
       _.withIncludeScala(false)
     },
     sparkVersion := sys.props.getOrElse("sparkVersion", "3.3.2"),
+    createVirtualEnv := {
+      val ret = (
+        s"${pythonVersion.value} -m venv ${venvFolder.value}" #&&
+          s"${venvFolder.value}/bin/pip install wheel==0.42.0 pytest==7.4.3 pyspark[ml]==${sparkVersion.value} black==23.3.0 flake8==5.0.4"
+      ).!
+      require(ret == 0, "Creating venv failed")
+    },
     pyTest := {
       val log = streams.value.log
 
       val artifactPath = (Compile / assembly).value.getAbsolutePath
+      val venv = venvFolder.value
+      val python = pythonVersion.value
+
       if (scalaVersion.value == "2.12.18" && sparkVersion.value >= "3.0.0" || scalaVersion.value == "2.11.12") {
-        val pythonVersion = if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9"
-        val ret = Process(
-          Seq("./run-pyspark-tests.sh", sparkVersion.value, pythonVersion),
-          cwd = baseDirectory.value,
-          extraEnv = "ARTIFACT_PATH" -> artifactPath
-        ).!
+        val ret = Process(Seq(s"$venv/bin/pytest", "src/test/python"), cwd = baseDirectory.value, extraEnv = "ARTIFACT_PATH" -> artifactPath, "PYTHONPATH" -> s"${baseDirectory.value}/src/main/python", "PYSPARK_PYTHON" -> python).!
         require(ret == 0, "Python tests failed")
       } else {
         // pyspark packages support just one version of scala. You cannot use 2.13.x because it ships with 2.12.x jars
         log.info(s"Running pyTests for Scala ${scalaVersion.value} and Spark ${sparkVersion.value} is not supported.")
       }
     },
+    pyTest := pyTest.dependsOn(assembly, createVirtualEnv).value,
+    blackCheck := {
+      val ret = s"${venvFolder.value}/bin/black --check ${baseDirectory.value}/src/main/python".!
+      require(ret == 0, "Black failed")
+    },
+    blackCheck := blackCheck.dependsOn(createVirtualEnv).value,
+    black := {
+      val ret = s"${venvFolder.value}/bin/black ${baseDirectory.value}/src/main/python".!
+      require(ret == 0, "Black failed")
+    },
+    black := black.dependsOn(createVirtualEnv).value,
+    flake8 := {
+      val ret = s"${venvFolder.value}/bin/flake8 ${baseDirectory.value}/src/main/python".!
+      require(ret == 0, "Flake8 failed")
+    },
+    flake8 := flake8.dependsOn(createVirtualEnv).value,
     test := {
       (Test / test).value
       (Test / pyTest).value
     },
-    pyTest := pyTest.dependsOn(assembly).value,
     libraryDependencies ++= Seq(
      "com.github.jelmerk"         % "hnswlib-utils"      % hnswLibVersion,
      "com.github.jelmerk"         % "hnswlib-core-jdk17" % hnswLibVersion,
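The new pyTest wiring above hands the assembled jar to the Python suite through the ARTIFACT_PATH environment variable, puts src/main/python on PYTHONPATH, and points PYSPARK_PYTHON at the venv interpreter. The fixtures under src/test/python are not part of this diff; the sketch below is a hypothetical conftest.py showing one way such a suite could consume those variables, with the fixture name and Spark settings chosen purely for illustration.

# Hypothetical conftest.py sketch: only ARTIFACT_PATH and the source layout come
# from the build definition above; the fixture itself is illustrative.
import os

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    # Put the hnswlib-spark assembly built by sbt on the classpath of a local session.
    session = (
        SparkSession.builder.master("local[*]")
        .config("spark.jars", os.environ["ARTIFACT_PATH"])
        .getOrCreate()
    )
    yield session
    session.stop()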
diff --git a/hnswlib-spark/run-pyspark-tests.sh b/hnswlib-spark/run-pyspark-tests.sh
deleted file mode 100755
index 20081f19..00000000
--- a/hnswlib-spark/run-pyspark-tests.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-SPARK_VERSION=$1
-PYTHON_VERSION=$2
-
-# add python sources on the path
-export PYTHONPATH=src/main/python
-
-# unset SPARK_HOME or it will use whatever is configured on the host system instead of the pip packages
-unset SPARK_HOME
-
-# create a virtual environment
-
-eval "$PYTHON_VERSION -m venv "target/spark-$SPARK_VERSION-venv""
-source "target/spark-$SPARK_VERSION-venv/bin/activate"
-
-# install packages
-pip install wheel==0.42.0
-pip install pytest==7.4.3
-pip install 'pyspark[ml]'=="$SPARK_VERSION"
-
-# run unit tests
-pytest --junitxml=target/test-reports/TEST-python.xml
diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py
index d7734861..3ace0352 100644
--- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py
+++ b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py
@@ -1,4 +1,4 @@
 import sys
 import pyspark_hnsw.conversion
 
-sys.modules['com.github.jelmerk.spark.conversion'] = pyspark_hnsw.conversion
+sys.modules["com.github.jelmerk.spark.conversion"] = pyspark_hnsw.conversion
diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py
index 0393eccc..e8a13f43 100644
--- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py
+++ b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py
@@ -1,4 +1,4 @@
 import sys
 import pyspark_hnsw.knn
 
-sys.modules['com.github.jelmerk.spark.knn.bruteforce'] = pyspark_hnsw.knn
+sys.modules["com.github.jelmerk.spark.knn.bruteforce"] = pyspark_hnsw.knn
diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py
index c26a2188..a11f230d 100644
--- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py
+++ b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py
@@ -1,4 +1,4 @@
 import sys
 import pyspark_hnsw.evaluation
 
-sys.modules['com.github.jelmerk.spark.knn.evaluation'] = pyspark_hnsw.evaluation
+sys.modules["com.github.jelmerk.spark.knn.evaluation"] = pyspark_hnsw.evaluation
diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py
index b3e87b00..175e7374 100644
--- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py
+++ b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py
@@ -1,4 +1,4 @@
 import sys
 import pyspark_hnsw.knn
 
-sys.modules['com.github.jelmerk.spark.knn.hnsw'] = pyspark_hnsw.knn
\ No newline at end of file
+sys.modules["com.github.jelmerk.spark.knn.hnsw"] = pyspark_hnsw.knn
diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py
index 6d0fdbe4..4cd8090b 100644
--- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py
+++ b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py
@@ -1,4 +1,4 @@
 import sys
 import pyspark_hnsw.linalg
 
-sys.modules['com.github.jelmerk.spark.linalg'] = pyspark_hnsw.linalg
\ No newline at end of file
+sys.modules["com.github.jelmerk.spark.linalg"] = pyspark_hnsw.linalg
diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py b/hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py
index 1e2d3598..cd64e8c5 100644
--- a/hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py
+++ b/hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py
@@ -6,13 +6,15 @@
 from pyspark.java_gateway import launch_gateway
 
 
-def start(spark23=False,
-          spark24=False,
-          spark31=False,
-          memory="16G",
-          cache_folder="/tmp",
-          real_time_output=False,
-          output_level=1):
+def start(
+    spark23=False,
+    spark24=False,
+    spark31=False,
+    memory="16G",
+    cache_folder="/tmp",
+    real_time_output=False,
+    output_level=1,
+):
     """Starts a PySpark instance with default parameters for Hnswlib.
 
     The default parameters would result in the equivalent of:
@@ -62,7 +64,6 @@ def start(spark23=False,
     current_version = "1.1.0"
 
     class HnswlibConfig:
-
         def __init__(self):
             self.master = "local[*]"
             self.app_name = "Hnswlib"
@@ -71,20 +72,27 @@ def __init__(self):
 
             # Hnswlib on Apache Spark 3.2.x
             # Hnswlib on Apache Spark 3.0.x/3.1.x
-            self.maven_spark = "com.github.jelmerk:hnswlib-spark_3.1_2.12:{}".format(current_version)
+            self.maven_spark = "com.github.jelmerk:hnswlib-spark_3.1_2.12:{}".format(
+                current_version
+            )
             # Hnswlib on Apache Spark 2.4.x
-            self.maven_spark24 = "com.github.jelmerk:hnswlib-spark_2.4_2.12:{}".format(current_version)
+            self.maven_spark24 = "com.github.jelmerk:hnswlib-spark_2.4_2.12:{}".format(
+                current_version
+            )
             # Hnswlib on Apache Spark 2.3.x
-            self.maven_spark23 = "com.github.jelmerk:hnswlib-spark_2.3_2.11:{}".format(current_version)
+            self.maven_spark23 = "com.github.jelmerk:hnswlib-spark_2.3_2.11:{}".format(
+                current_version
+            )
 
     def start_without_realtime_output():
-        builder = SparkSession.builder \
-            .appName(spark_nlp_config.app_name) \
-            .master(spark_nlp_config.master) \
-            .config("spark.driver.memory", memory) \
-            .config("spark.serializer", spark_nlp_config.serializer) \
-            .config("spark.kryo.registrator", spark_nlp_config.registrator) \
+        builder = (
+            SparkSession.builder.appName(spark_nlp_config.app_name)
+            .master(spark_nlp_config.master)
+            .config("spark.driver.memory", memory)
+            .config("spark.serializer", spark_nlp_config.serializer)
+            .config("spark.kryo.registrator", spark_nlp_config.registrator)
             .config("spark.hnswlib.settings.index.cache_folder", cache_folder)
+        )
 
         if spark23:
             builder.config("spark.jars.packages", spark_nlp_config.maven_spark23)
@@ -96,9 +104,7 @@ def start_without_realtime_output():
         return builder.getOrCreate()
 
     def start_with_realtime_output():
-
         class SparkWithCustomGateway:
-
             def __init__(self):
                 spark_conf = SparkConf()
                 spark_conf.setAppName(spark_nlp_config.app_name)
@@ -107,17 +113,21 @@ def __init__(self):
                 spark_conf.set("spark.serializer", spark_nlp_config.serializer)
                 spark_conf.set("spark.kryo.registrator", spark_nlp_config.registrator)
                 spark_conf.set("spark.jars.packages", spark_nlp_config.maven_spark)
-                spark_conf.set("spark.hnswlib.settings.index.cache_folder", cache_folder)
+                spark_conf.set(
+                    "spark.hnswlib.settings.index.cache_folder", cache_folder
+                )
 
                 # Make the py4j JVM stdout and stderr available without buffering
                 popen_kwargs = {
-                    'stdout': subprocess.PIPE,
-                    'stderr': subprocess.PIPE,
-                    'bufsize': 0
+                    "stdout": subprocess.PIPE,
+                    "stderr": subprocess.PIPE,
+                    "bufsize": 0,
                 }
                 # Launch the gateway with our custom settings
-                self.gateway = launch_gateway(conf=spark_conf, popen_kwargs=popen_kwargs)
+                self.gateway = launch_gateway(
+                    conf=spark_conf, popen_kwargs=popen_kwargs
+                )
                 self.process = self.gateway.proc
                 # Use the gateway we launched
                 spark_context = SparkContext(gateway=self.gateway)
@@ -132,15 +142,15 @@ def std_background_listeners(self):
                 self.error_thread.start()
 
             def output_reader(self):
-                for line in iter(self.process.stdout.readline, b''):
-                    print('{0}'.format(line.decode('utf-8')), end='')
+                for line in iter(self.process.stdout.readline, b""):
+                    print("{0}".format(line.decode("utf-8")), end="")
 
             def error_reader(self):
-                RED = '\033[91m'
-                RESET = '\033[0m'
-                for line in iter(self.process.stderr.readline, b''):
+                RED = "\033[91m"
+                RESET = "\033[0m"
+                for line in iter(self.process.stderr.readline, b""):
                     if output_level == 0:
-                        print(RED + '{0}'.format(line.decode('utf-8')) + RESET, end='')
+                        print(RED + "{0}".format(line.decode("utf-8")) + RESET, end="")
                     else:
                         # output just info
                         pass
@@ -164,7 +174,6 @@ def shutdown(self):
     else:
         # Available from Spark 3.0.x
         class SparkRealTimeOutput:
-
             def __init__(self):
                 self.__spark_with_custom_gateway = start_with_realtime_output()
                 self.spark_session = self.__spark_with_custom_gateway.spark_session
@@ -186,4 +195,4 @@ def version():
     str
         The current Hnswlib version.
     """
-    return '1.1.0'
+    return "1.1.0"
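The reformatted start() helper above keeps its original signature and defaults. Below is a short, hypothetical usage sketch based only on the start() signature and version() function visible in this diff; the argument values are arbitrary examples, and with real_time_output left at False the call returns a plain SparkSession.

# Usage sketch for pyspark_hnsw.start(); values are illustrative only.
import pyspark_hnsw

# Returns a SparkSession configured with the hnswlib-spark package.
spark = pyspark_hnsw.start(memory="8G", cache_folder="/tmp/hnswlib")

print(pyspark_hnsw.version())  # "1.1.0"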
diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py b/hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py
index 1dcfd366..81cb7235 100644
--- a/hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py
+++ b/hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py
@@ -1,19 +1,35 @@
-from pyspark.ml.param.shared import *
+from pyspark.ml.param.shared import (
+    Params,
+    Param,
+    TypeConverters,
+    HasInputCol,
+    HasOutputCol,
+)
 from pyspark.ml.wrapper import JavaTransformer
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 from pyspark.mllib.common import inherit_doc
+
+# noinspection PyProtectedMember
 from pyspark import keyword_only
 
-__all__ = ['VectorConverter']
+__all__ = ["VectorConverter"]
+
+
+# noinspection PyPep8Naming
 @inherit_doc
-class VectorConverter(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class VectorConverter(
+    JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable
+):
     """
     Converts the input vector to a vector of another type.
     """
 
-    outputType = Param(Params._dummy(), "outputType", "type of vector to produce. one of array<float>, array<double>, vector",
-                       typeConverter=TypeConverters.toString)
+    outputType = Param(
+        Params._dummy(),
+        "outputType",
+        "type of vector to produce. one of array<float>, array<double>, vector",
+        typeConverter=TypeConverters.toString,
+    )
 
     @keyword_only
     def __init__(self, inputCol="input", outputCol="output", outputType="array"):
@@ -21,12 +37,16 @@ def __init__(self, inputCol="input", outputCol="output", outputType="array