Add black and flake8
Jelmer Kuperus committed Nov 3, 2024
1 parent 708ff5e commit e3623e0
Showing 15 changed files with 417 additions and 165 deletions.
6 changes: 6 additions & 0 deletions .flake8
@@ -0,0 +1,6 @@
[flake8]
exclude =
.git,
__pycache__,
.pytest_cache
max-line-length = 120
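
The .flake8 settings above can also be exercised outside of sbt. The sketch below is a hedged illustration only: it assumes flake8 5.x is installed and is run from the repository root, uses flake8's documented legacy API with options mirroring the configuration file, and borrows the source path from the build changes later in this commit.

# Hedged sketch: run flake8 programmatically with settings that mirror .flake8.
# Assumes flake8 5.x; option keywords are the long option names with dashes
# replaced by underscores.
from flake8.api import legacy as flake8

style_guide = flake8.get_style_guide(
    max_line_length=120,
    exclude=[".git", "__pycache__", ".pytest_cache"],
)

# Path assumed from the sources linted elsewhere in this commit.
report = style_guide.check_files(["hnswlib-spark/src/main/python"])
print(f"flake8 reported {report.total_errors} violation(s)")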
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -43,7 +43,7 @@ jobs:
3.9
- name: Build and test
run: |
sbt -java-home "$JAVA_HOME_8_X64" clean scalafmtCheckAll +test -DsparkVersion="$SPARK_VERSION"
sbt -java-home "$JAVA_HOME_8_X64" clean scalafmtCheckAll blackCheck flake8 +test pyTest -DsparkVersion="$SPARK_VERSION"
- name: Publish Unit test results
uses: mikepenz/action-junit-report@v4
with:
47 changes: 39 additions & 8 deletions build.sbt
@@ -43,12 +43,24 @@ lazy val noPublishSettings =
val hnswLibVersion = "1.1.2"
val sparkVersion = settingKey[String]("Spark version")

lazy val pyTest = taskKey[Unit]("Run the python tests")
lazy val createVirtualEnv = taskKey[Unit]("Create venv")
lazy val pyTest = taskKey[Unit]("Run the python tests")
lazy val black = taskKey[Unit]("Run the black code formatter")
lazy val blackCheck = taskKey[Unit]("Run the black code formatter in check mode")
lazy val flake8 = taskKey[Unit]("Run the flake8 style enforcer")

lazy val root = (project in file("."))
.aggregate(hnswlibSpark)
.settings(noPublishSettings)

lazy val pythonVersion = Def.setting {
if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9"
}

lazy val venvFolder = Def.setting {
s"${baseDirectory.value}/.venv"
}

lazy val hnswlibSpark = (project in file("hnswlib-spark"))
.settings(
name := s"hnswlib-spark_${sparkVersion.value.split('.').take(2).mkString("_")}",
@@ -78,28 +90,47 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark"))
_.withIncludeScala(false)
},
sparkVersion := sys.props.getOrElse("sparkVersion", "3.3.2"),
createVirtualEnv := {
val ret = (
s"${pythonVersion.value} -m venv ${venvFolder.value}" #&&
s"${venvFolder.value}/bin/pip install wheel==0.42.0 pytest==7.4.3 pyspark[ml]==${sparkVersion.value} black==23.3.0 flake8==5.0.4"
).!
require(ret == 0, "Creating venv failed")
},
pyTest := {
val log = streams.value.log

val artifactPath = (Compile / assembly).value.getAbsolutePath
val venv = venvFolder.value

if (scalaVersion.value == "2.12.18" && sparkVersion.value >= "3.0.0" || scalaVersion.value == "2.11.12") {
val pythonVersion = if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9"
val ret = Process(
Seq("./run-pyspark-tests.sh", sparkVersion.value, pythonVersion),
Seq(s"$venv/bin/pytest", "--junitxml=target/test-reports/TEST-python.xml", "src/test/python"),
cwd = baseDirectory.value,
extraEnv = "ARTIFACT_PATH" -> artifactPath
extraEnv = "ARTIFACT_PATH" -> artifactPath, "PYTHONPATH" -> s"${baseDirectory.value}/src/main/python"
).!
require(ret == 0, "Python tests failed")
} else {
// pyspark packages support just one version of scala. You cannot use 2.13.x because it ships with 2.12.x jars
log.info(s"Running pyTests for Scala ${scalaVersion.value} and Spark ${sparkVersion.value} is not supported.")
}
},
test := {
(Test / test).value
(Test / pyTest).value
pyTest := pyTest.dependsOn(assembly, createVirtualEnv).value,
blackCheck := {
val ret = s"${venvFolder.value}/bin/black --check ${baseDirectory.value}/src/main/python".!
require(ret == 0, "Black failed")
},
blackCheck := blackCheck.dependsOn(createVirtualEnv).value,
black := {
val ret = s"${venvFolder.value}/bin/black ${baseDirectory.value}/src/main/python".!
require(ret == 0, "Black failed")
},
black := black.dependsOn(createVirtualEnv).value,
flake8 := {
val ret = s"${venvFolder.value}/bin/flake8 ${baseDirectory.value}/src/main/python".!
require(ret == 0, "Flake8 failed")
},
pyTest := pyTest.dependsOn(assembly).value,
flake8 := flake8.dependsOn(createVirtualEnv).value,
libraryDependencies ++= Seq(
"com.github.jelmerk" % "hnswlib-utils" % hnswLibVersion,
"com.github.jelmerk" % "hnswlib-core-jdk17" % hnswLibVersion,
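
The new sbt tasks above create a project-local virtualenv and shell out to the tools installed in it. As a rough local equivalent, here is a hedged Python sketch; the virtualenv and source paths are assumptions derived from venvFolder and the task definitions, and the ARTIFACT_PATH and PYTHONPATH environment variables set by the real pyTest task are omitted for brevity.

# Hedged sketch: run the checks that the blackCheck, flake8 and pyTest sbt
# tasks invoke, using the virtualenv created by createVirtualEnv.
import subprocess
import sys

VENV = "hnswlib-spark/.venv"               # assumed location of venvFolder
SOURCES = "hnswlib-spark/src/main/python"  # sources formatted and linted
TESTS = "hnswlib-spark/src/test/python"    # pytest suite

CHECKS = [
    [f"{VENV}/bin/black", "--check", SOURCES],
    [f"{VENV}/bin/flake8", SOURCES],
    [f"{VENV}/bin/pytest", "--junitxml=target/test-reports/TEST-python.xml", TESTS],
]

for cmd in CHECKS:
    print("running:", " ".join(cmd))
    result = subprocess.run(cmd)
    if result.returncode != 0:
        # Mirrors the require(ret == 0, ...) guards in build.sbt.
        sys.exit(result.returncode)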
25 changes: 0 additions & 25 deletions hnswlib-spark/run-pyspark-tests.sh

This file was deleted.

@@ -1,4 +1,4 @@
import sys
import pyspark_hnsw.conversion

sys.modules['com.github.jelmerk.spark.conversion'] = pyspark_hnsw.conversion
sys.modules["com.github.jelmerk.spark.conversion"] = pyspark_hnsw.conversion
@@ -1,4 +1,4 @@
import sys
import pyspark_hnsw.knn

sys.modules['com.github.jelmerk.spark.knn.bruteforce'] = pyspark_hnsw.knn
sys.modules["com.github.jelmerk.spark.knn.bruteforce"] = pyspark_hnsw.knn
@@ -1,4 +1,4 @@
import sys
import pyspark_hnsw.evaluation

sys.modules['com.github.jelmerk.spark.knn.evaluation'] = pyspark_hnsw.evaluation
sys.modules["com.github.jelmerk.spark.knn.evaluation"] = pyspark_hnsw.evaluation
@@ -1,4 +1,4 @@
import sys
import pyspark_hnsw.knn

sys.modules['com.github.jelmerk.spark.knn.hnsw'] = pyspark_hnsw.knn
sys.modules["com.github.jelmerk.spark.knn.hnsw"] = pyspark_hnsw.knn
@@ -1,4 +1,4 @@
import sys
import pyspark_hnsw.linalg

sys.modules['com.github.jelmerk.spark.linalg'] = pyspark_hnsw.linalg
sys.modules["com.github.jelmerk.spark.linalg"] = pyspark_hnsw.linalg
73 changes: 41 additions & 32 deletions hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py
@@ -6,13 +6,15 @@
from pyspark.java_gateway import launch_gateway


def start(spark23=False,
spark24=False,
spark31=False,
memory="16G",
cache_folder="/tmp",
real_time_output=False,
output_level=1):
def start(
spark23=False,
spark24=False,
spark31=False,
memory="16G",
cache_folder="/tmp",
real_time_output=False,
output_level=1,
):
"""Starts a PySpark instance with default parameters for Hnswlib.
The default parameters would result in the equivalent of:
@@ -62,7 +64,6 @@ def start(spark23=False,
current_version = "1.1.0"

class HnswlibConfig:

def __init__(self):
self.master = "local[*]"
self.app_name = "Hnswlib"
@@ -71,20 +72,27 @@ def __init__(self):
# Hnswlib on Apache Spark 3.2.x

# Hnswlib on Apache Spark 3.0.x/3.1.x
self.maven_spark = "com.github.jelmerk:hnswlib-spark_3.1_2.12:{}".format(current_version)
self.maven_spark = "com.github.jelmerk:hnswlib-spark_3.1_2.12:{}".format(
current_version
)
# Hnswlib on Apache Spark 2.4.x
self.maven_spark24 = "com.github.jelmerk:hnswlib-spark_2.4_2.12:{}".format(current_version)
self.maven_spark24 = "com.github.jelmerk:hnswlib-spark_2.4_2.12:{}".format(
current_version
)
# Hnswlib on Apache Spark 2.3.x
self.maven_spark23 = "com.github.jelmerk:hnswlib-spark_2.3_2.11:{}".format(current_version)
self.maven_spark23 = "com.github.jelmerk:hnswlib-spark_2.3_2.11:{}".format(
current_version
)

def start_without_realtime_output():
builder = SparkSession.builder \
.appName(spark_nlp_config.app_name) \
.master(spark_nlp_config.master) \
.config("spark.driver.memory", memory) \
.config("spark.serializer", spark_nlp_config.serializer) \
.config("spark.kryo.registrator", spark_nlp_config.registrator) \
builder = (
SparkSession.builder.appName(spark_nlp_config.app_name)
.master(spark_nlp_config.master)
.config("spark.driver.memory", memory)
.config("spark.serializer", spark_nlp_config.serializer)
.config("spark.kryo.registrator", spark_nlp_config.registrator)
.config("spark.hnswlib.settings.index.cache_folder", cache_folder)
)

if spark23:
builder.config("spark.jars.packages", spark_nlp_config.maven_spark23)
@@ -96,9 +104,7 @@ def start_without_realtime_output():
return builder.getOrCreate()

def start_with_realtime_output():

class SparkWithCustomGateway:

def __init__(self):
spark_conf = SparkConf()
spark_conf.setAppName(spark_nlp_config.app_name)
@@ -107,17 +113,21 @@ def __init__(self):
spark_conf.set("spark.serializer", spark_nlp_config.serializer)
spark_conf.set("spark.kryo.registrator", spark_nlp_config.registrator)
spark_conf.set("spark.jars.packages", spark_nlp_config.maven_spark)
spark_conf.set("spark.hnswlib.settings.index.cache_folder", cache_folder)
spark_conf.set(
"spark.hnswlib.settings.index.cache_folder", cache_folder
)

# Make the py4j JVM stdout and stderr available without buffering
popen_kwargs = {
'stdout': subprocess.PIPE,
'stderr': subprocess.PIPE,
'bufsize': 0
"stdout": subprocess.PIPE,
"stderr": subprocess.PIPE,
"bufsize": 0,
}

# Launch the gateway with our custom settings
self.gateway = launch_gateway(conf=spark_conf, popen_kwargs=popen_kwargs)
self.gateway = launch_gateway(
conf=spark_conf, popen_kwargs=popen_kwargs
)
self.process = self.gateway.proc
# Use the gateway we launched
spark_context = SparkContext(gateway=self.gateway)
@@ -132,15 +142,15 @@ def std_background_listeners(self):
self.error_thread.start()

def output_reader(self):
for line in iter(self.process.stdout.readline, b''):
print('{0}'.format(line.decode('utf-8')), end='')
for line in iter(self.process.stdout.readline, b""):
print("{0}".format(line.decode("utf-8")), end="")

def error_reader(self):
RED = '\033[91m'
RESET = '\033[0m'
for line in iter(self.process.stderr.readline, b''):
RED = "\033[91m"
RESET = "\033[0m"
for line in iter(self.process.stderr.readline, b""):
if output_level == 0:
print(RED + '{0}'.format(line.decode('utf-8')) + RESET, end='')
print(RED + "{0}".format(line.decode("utf-8")) + RESET, end="")
else:
# output just info
pass
@@ -164,7 +174,6 @@ def shutdown(self):
else:
# Available from Spark 3.0.x
class SparkRealTimeOutput:

def __init__(self):
self.__spark_with_custom_gateway = start_with_realtime_output()
self.spark_session = self.__spark_with_custom_gateway.spark_session
@@ -186,4 +195,4 @@ def version():
str
The current Hnswlib version.
"""
return '1.1.0'
return "1.1.0"
34 changes: 27 additions & 7 deletions hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py
@@ -1,32 +1,52 @@
from pyspark.ml.param.shared import *
from pyspark.ml.param.shared import (
Params,
Param,
TypeConverters,
HasInputCol,
HasOutputCol,
)
from pyspark.ml.wrapper import JavaTransformer
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.mllib.common import inherit_doc

# noinspection PyProtectedMember
from pyspark import keyword_only

__all__ = ['VectorConverter']
__all__ = ["VectorConverter"]


# noinspection PyPep8Naming
@inherit_doc
class VectorConverter(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
class VectorConverter(
JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable
):
"""
Converts the input vector to a vector of another type.
"""

outputType = Param(Params._dummy(), "outputType", "type of vector to produce. one of array<float>, array<double>, vector",
typeConverter=TypeConverters.toString)
outputType = Param(
Params._dummy(),
"outputType",
"type of vector to produce. one of array<float>, array<double>, vector",
typeConverter=TypeConverters.toString,
)

@keyword_only
def __init__(self, inputCol="input", outputCol="output", outputType="array<float>"):
"""
__init__(self, inputCol="input", outputCol="output", outputType="array<float>")
"""
super(VectorConverter, self).__init__()
self._java_obj = self._new_java_obj("com.github.jelmerk.spark.conversion.VectorConverter", self.uid)
self._java_obj = self._new_java_obj(
"com.github.jelmerk.spark.conversion.VectorConverter", self.uid
)
kwargs = self._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, inputCol="input", outputCol="output", outputType="array<float>"):
def setParams(
self, inputCol="input", outputCol="output", outputType="array<float>"
):
"""
setParams(self, inputCol="input", outputCol="output", outputType="array<float>")
Sets params for this VectorConverter.
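
A hedged usage sketch for the VectorConverter wrapper shown above: it assumes an active SparkSession with the hnswlib-spark jar on the classpath, and the column names are illustrative.

# Hedged sketch: convert an ml Vector column into array<float> with the
# VectorConverter transformer defined above.
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark_hnsw.conversion import VectorConverter

spark = SparkSession.builder.getOrCreate()  # assumes hnswlib-spark is on the classpath

df = spark.createDataFrame([(Vectors.dense([0.1, 0.2, 0.3]),)], ["features"])

converter = VectorConverter(
    inputCol="features", outputCol="features_float", outputType="array<float>"
)
converter.transform(df).show(truncate=False)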