Allow to install PIP packages into PySpark job (#215)
EnricoMi authored Dec 14, 2023
1 parent ac6f6cb commit 57aec01
Showing 8 changed files with 209 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/actions/test-python/action.yml
@@ -70,7 +70,7 @@ runs:
python -m pip install --upgrade pip
pip install pypandoc
pip install -r python/requirements-${{ inputs.spark-compat-version }}_${{ inputs.scala-compat-version }}.txt
pip install pytest unittest-xml-reporting
pip install -r python/test/requirements.txt
SPARK_HOME=$(python -c "import pyspark; import os; print(os.path.dirname(pyspark.__file__))")
echo "SPARK_HOME=$SPARK_HOME" | tee -a "$GITHUB_ENV"
36 changes: 36 additions & 0 deletions PYSPARK-DEPS.md
@@ -0,0 +1,36 @@
# PySpark dependencies

Using PySpark on a cluster requires all cluster nodes to have the Python packages installed that the PySpark job depends on.
Deploying those packages can be cumbersome, especially when working in an interactive notebook.

The `spark-extension` package allows the PySpark application itself to install Python packages programmatically (PySpark ≥ 3.1.0).
These packages are accessible only to that PySpark application, and they are removed when `spark.stop()` is called.

```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *

spark.install_pip_package("pandas", "pyarrow")
```

The above example installs the packages `pandas` and `pyarrow` via `pip`. The `install_pip_package` method accepts any `pip` command-line argument:

```python
# install packages with version specs
spark.install_pip_package("pandas==1.4.3", "pyarrow~=8.0.0")

# install packages from package sources (e.g. git clone https://github.com/pandas-dev/pandas.git)
spark.install_pip_package("./pandas/")

# install packages from a git repo
spark.install_pip_package("git+https://github.com/pandas-dev/pandas.git@main")

# use a pip cache directory to cache downloaded whl files
spark.install_pip_package("pandas", "pyarrow", "--cache-dir", "/home/user/.cache/pip")

# use an alternative index url (other than https://pypi.org/simple)
spark.install_pip_package("pandas", "pyarrow", "--index-url", "https://artifacts.company.com/pypi/simple")

# install packages quietly (--quiet only silences pip's output)
spark.install_pip_package("pandas", "pyarrow", "--quiet")
```
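
Once installed, the packages can be used on the driver as well as inside executor-side functions such as `mapInPandas`. The sketch below is modelled on this commit's tests and assumes a running `spark` session:

```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *

spark.install_pip_package("emoji", "pandas", "pyarrow")

import emoji
import pandas as pd

# the installed package is importable on the driver ...
print(emoji.emojize("install works :thumbs_up:"))

# ... and inside executor-side functions such as mapInPandas
df = spark.range(0, 10, 1, 10).mapInPandas(
    lambda batches: (pd.DataFrame({"val": [emoji.emojize(":thumbs_up:")]}) for _ in batches),
    "val string")
df.show()
```
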
41 changes: 40 additions & 1 deletion README.md
@@ -22,6 +22,16 @@ efficiently laid out with a single operation.
or [parquet-cli](https://pypi.org/project/parquet-cli/) by reading from a simple Spark data source.
This simplifies identifying why some Parquet files cannot be split by Spark into scalable partitions.

**[Install PIP packages into PySpark job](PYSPARK-DEPS.md):** Install Python dependencies via PIP directly into your running PySpark job (PySpark ≥ 3.1.0):

```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *

spark.install_pip_package("pandas==1.4.3", "pyarrow")
spark.install_pip_package("-r", "requirements.txt")
```

**[Fluent method call](CONDITIONAL.md):** `T.call(transformation: T => R): R`: Turns a transformation `T => R`,
that is not part of `T`, into a fluent method call on `T`. This allows writing fluent code like:

@@ -52,13 +62,17 @@ should be preferred over calling `Dataset.groupByKey(V => K)` whenever possible.
existing partitioning and ordering of the Dataset, while the latter hides from Catalyst which columns are used to create the keys.
This can have a significant performance penalty.

<details>
<summary>Details:</summary>

The new column-expression-based `groupByKey[K](Column*)` method makes it easier to group by a column expression key. Instead of

ds.groupBy($"id").as[Int, V]

use:

ds.groupByKey[Int]($"id")
</details>

**Backticks:** `backticks(string: String, strings: String*): String`: Encloses the given column name with backticks (`` ` ``) when needed.
This is a handy way to ensure column names with special characters like dots (`.`) work with `col()` or `select()`.
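
For illustration, the following plain-PySpark sketch (it does not use the helper itself, and assumes a running `spark` session) shows why the quoting is needed:

```python
from pyspark.sql.functions import col

df = spark.createDataFrame([(1,)], ["a.column"])

df.select(col("`a.column`")).show()  # backticks: "a.column" is treated as a single column name
# df.select(col("a.column"))         # without backticks: parsed as field "column" of a struct "a" and fails
```
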
@@ -99,7 +113,31 @@ unix_epoch_nanos_to_dotnet_ticks(column_or_name)
```
</details>

**Spark job description:** Set Spark job description for all Spark jobs within a context:
**Spark temporary directory:** Create a temporary directory that will be removed on Spark application shutdown.

<details>
<summary>Examples:</summary>

Scala:
```scala
import uk.co.gresearch.spark.createTemporaryDir

val dir = createTemporaryDir("prefix")
```

Python:
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *

dir = spark.create_temporary_dir("prefix")
```
</details>

**Spark job description:** Set Spark job description for all Spark jobs within a context.

<details>
<summary>Examples:</summary>

```scala
import uk.co.gresearch.spark._
@@ -140,6 +178,7 @@ val counts = withJobDescription("Counting rows") {
files.map(filename => spark.read.csv(filename).count).sum
}(spark)
```
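
A Python counterpart can be sketched with the `append_job_description` context manager added to the Python module in this commit; the exact import shown here is an assumption:

```python
# noinspection PyUnresolvedReferences
from gresearch.spark import append_job_description

with append_job_description("Counting rows"):
    counts = sum(spark.read.csv(filename).count() for filename in files)
```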
</details>

## Using Spark Extension

57 changes: 57 additions & 0 deletions python/gresearch/spark/__init__.py
@@ -12,11 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import sys
import time
from contextlib import contextmanager
from typing import Any, Union, List, Optional, Mapping, TYPE_CHECKING
import subprocess

from py4j.java_gateway import JVMView, JavaObject
from pyspark import __version__
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.sql import DataFrame
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.context import SQLContext
@@ -405,3 +412,53 @@ def append_job_description(extra_description: str, separator: str = " - "):
        yield
    finally:
        set_description(earlier)


def create_temporary_dir(spark: Union[SparkSession, SparkContext], prefix: str) -> str:
"""
Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown.
:param spark: spark session or context
:param prefix: prefix string of temporary directory name
:return: absolute path of temporary directory
"""
if isinstance(spark, SparkSession):
spark = spark.sparkContext

package = spark._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
mktempdir = package.createTemporaryDir
return mktempdir(prefix)


SparkSession.create_temporary_dir = create_temporary_dir
SparkContext.create_temporary_dir = create_temporary_dir


def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pip_option: str) -> None:
    if __version__.startswith('2.') or __version__.startswith('3.0.'):
        raise NotImplementedError(f'Not supported for PySpark {__version__}')

    if isinstance(spark, SparkSession):
        spark = spark.sparkContext

    # create temporary directory for packages, inside a directory which will be deleted on spark application shutdown
    id = f"spark-extension-pip-pkgs-{time.time()}"
    dir = spark.create_temporary_dir(f"{id}-")

    # install packages via pip install
    # it is best to run pip as a separate process rather than calling into the pip module
    # https://pip.pypa.io/en/stable/user_guide/#using-pip-from-your-program
    subprocess.check_call([sys.executable, '-m', 'pip', "install"] + list(package_or_pip_option) + ["--target", dir])

    # zip packages and remove directory
    zip = shutil.make_archive(dir, "zip", dir)
    shutil.rmtree(dir)

    # register zip file as archive, and add as python source
    # once support for Spark 3.0 is dropped, replace with spark.addArchive()
    spark._jsc.sc().addArchive(zip + "#" + id)
    spark._python_includes.append(id)
    sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), id))


SparkSession.install_pip_package = install_pip_package
SparkContext.install_pip_package = install_pip_package
4 changes: 4 additions & 0 deletions python/test/requirements.txt
@@ -0,0 +1,4 @@
pandas
pyarrow
pytest
unittest-xml-reporting
50 changes: 49 additions & 1 deletion python/test/test_package.py
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from decimal import Decimal
from subprocess import CalledProcessError
from unittest import skipUnless, skipIf

from pyspark import __version__
from pyspark.sql import Row
from pyspark.sql.functions import col, count

from gresearch.spark import dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, count_null
from spark_common import SparkTest
from decimal import Decimal


class PackageTest(SparkTest):
@@ -151,6 +155,50 @@ def test_count_null(self):
        ).collect()
        self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual)

    def test_create_temp_dir(self):
        from pyspark import SparkFiles

        dir = self.spark.create_temporary_dir("prefix")
        self.assertTrue(dir.startswith(SparkFiles.getRootDirectory()))

    @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
    def test_install_pip_package(self):
        self.spark.sparkContext.setLogLevel("INFO")
        with self.assertRaises(ImportError):
            # noinspection PyPackageRequirements
            import emoji
            emoji.emojize("this test is :thumbs_up:")

        self.spark.install_pip_package("emoji")

        # noinspection PyPackageRequirements
        import emoji
        actual = emoji.emojize("this test is :thumbs_up:")
        expected = "this test is 👍"
        self.assertEqual(expected, actual)

        import pandas as pd
        actual = self.spark.range(0, 10, 1, 10) \
            .mapInPandas(lambda it: [pd.DataFrame.from_dict({"val": [emoji.emojize(":thumbs_up:")]})], "val string") \
            .collect()
        expected = [Row("👍")] * 10
        self.assertEqual(expected, actual)

    @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
    def test_install_pip_package_unknown_argument(self):
        with self.assertRaises(CalledProcessError):
            self.spark.install_pip_package("--unknown", "argument")

    @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
    def test_install_pip_package_package_not_found(self):
        with self.assertRaises(CalledProcessError):
            self.spark.install_pip_package("pyspark-extension==abc")

    @skipUnless(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
    def test_install_pip_package_not_supported(self):
        with self.assertRaises(NotImplementedError):
            self.spark.install_pip_package("emoji")


if __name__ == '__main__':
    SparkTest.main()
14 changes: 13 additions & 1 deletion src/main/scala/uk/co/gresearch/spark/package.scala
@@ -16,16 +16,18 @@

package uk.co.gresearch

import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.functions.{col, count, lit, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, LongType, TimestampType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{SparkContext, SparkFiles}
import uk.co.gresearch.spark.group.SortedGroupByDataset

import java.nio.file.{Files, Paths}

package object spark extends Logging with SparkVersion with BuildVersion {

/**
@@ -37,6 +39,16 @@ package object spark extends Logging with SparkVersion with BuildVersion {
"_" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1)
}

  /**
   * Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown.
   * @param prefix prefix string of temporary directory name
   * @return absolute path of temporary directory
   */
  def createTemporaryDir(prefix: String): String = {
    // SparkFiles.getRootDirectory() will be deleted on spark application shutdown
    Files.createTempDirectory(Paths.get(SparkFiles.getRootDirectory()), prefix).toAbsolutePath.toString
  }

  // https://issues.apache.org/jira/browse/SPARK-40588
  private[spark] def writePartitionedByRequiresCaching[T](ds: Dataset[T]): Boolean = {
    ds.sparkSession.conf.get(
11 changes: 9 additions & 2 deletions src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
@@ -16,17 +16,18 @@

package uk.co.gresearch.spark

import org.apache.spark.TaskContext
import org.apache.spark.{SparkFiles, TaskContext}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Descending, SortOrder}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.{DISK_ONLY, MEMORY_AND_DISK, MEMORY_ONLY, OFF_HEAP, NONE}
import org.apache.spark.storage.StorageLevel.{DISK_ONLY, MEMORY_AND_DISK, MEMORY_ONLY, NONE, OFF_HEAP}
import org.scalatest.funsuite.AnyFunSuite
import uk.co.gresearch.ExtendedAny
import uk.co.gresearch.spark.SparkSuite.{Value, collectJobDescription}

import java.nio.file.Paths
import java.sql.Timestamp
import java.time.Instant

@@ -697,6 +698,12 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
}
}
}

test("Spark temp dir") {
import uk.co.gresearch.spark.createTemporaryDir
val dir = createTemporaryDir("test")
assert(Paths.get(dir).toAbsolutePath.toString.startsWith(SparkFiles.getRootDirectory()))
}
}

object SparkSuite {
