[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)

## What changes were proposed in this pull request?

This PR proposes to explicitly specify the required Pandas and PyArrow versions in PySpark tests, and to use them to decide whether to skip or run each test.

We already declare the extra dependencies in setup.py:

https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204

In the case of PyArrow:

Currently we only check whether PyArrow is installed, without checking its version, so the test run already fails outright with an unsupported version. For example, if PyArrow 0.7.0 is installed:

```
======================================================================
ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type
    f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
  File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf
    return _create_udf(f=f, returnType=return_type, evalType=eval_type)
  File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf
    require_minimum_pyarrow_version()
  File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version
    "however, your version was %s." % pyarrow.__version__)
ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0.

----------------------------------------------------------------------
Ran 33 tests in 8.098s

FAILED (errors=33)
```

In the case of Pandas:

A few tests targeting old Pandas behavior were run only when the installed Pandas version was lower than required; I rewrote them so they are exercised both when the Pandas version is too low and when Pandas is missing entirely.
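
A minimal sketch of the resulting skip-or-test pattern (it mirrors the tests.py hunk further down; `str(e)` stands in for pyspark's `_exception_message` helper, and the class name is illustrative):

```
import unittest

_pandas_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    # Raised both when Pandas is missing and when it is too old,
    # so one message covers both skip reasons.
    _pandas_requirement_message = str(e)

_have_pandas = _pandas_requirement_message is None


@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
class PandasDependentTests(unittest.TestCase):
    def test_needs_pandas(self):
        pass
```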

## How was this patch tested?

Manually tested by modifying the version conditions; each block below shows the skip messages for one scenario:

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
```

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
```

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
```

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes apache#20487 from HyukjinKwon/pyarrow-pandas-skip.
HyukjinKwon authored and Robert Kruszewski committed Feb 12, 2018
1 parent 10add8b commit d824d5a
Showing 6 changed files with 89 additions and 48 deletions.
4 changes: 4 additions & 0 deletions pom.xml
@@ -200,6 +200,10 @@
<paranamer.version>2.8</paranamer.version>
<maven-antrun.version>1.8</maven-antrun.version>
<commons-crypto.version>1.0.0</commons-crypto.version>
<!--
If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py,
./python/run-tests.py and ./python/setup.py too.
-->
<arrow.version>0.8.0</arrow.version>

<test.java.home>${java.home}</test.java.home>
3 changes: 3 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -1913,6 +1913,9 @@ def toPandas(self):
0 2 Alice
1 5 Bob
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()

import pandas as pd

if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
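
For context, under this guard the user-facing behavior is to fail fast (a hypothetical session sketch; assumes a running SparkSession named `spark` and a Pandas installation that does not satisfy the requirement):

```
# Hypothetical: with Pandas missing or older than 0.19.2, toPandas()
# now raises immediately with a clear requirement message instead of
# failing with an obscure error deeper in the conversion.
df = spark.range(3)
try:
    df.toPandas()
except ImportError as e:
    print(e)  # e.g. "Pandas >= 0.19.2 must be installed; however, it was not found."
```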
3 changes: 3 additions & 0 deletions python/pyspark/sql/session.py
@@ -646,6 +646,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()

if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
== "true":
timezone = self.conf.get("spark.sql.session.timeZone")
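
Note that the check sits inside the `isinstance(data, pandas.DataFrame)` branch, so inputs that do not involve Pandas are unaffected (a sketch under the same assumptions as above):

```
# Hypothetical: only Pandas-backed input triggers the version check.
spark.createDataFrame([(1, "a")], ["id", "v"])  # no Pandas requirement

import pandas as pd  # assumes Pandas is importable here
pdf = pd.DataFrame({"id": [1], "v": ["a"]})
spark.createDataFrame(pdf)  # raises ImportError if Pandas is older than 0.19.2
```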
87 changes: 48 additions & 39 deletions python/pyspark/sql/tests.py
@@ -48,19 +48,26 @@
else:
import unittest

_have_pandas = False
_have_old_pandas = False
_pandas_requirement_message = None
try:
import pandas
try:
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
_have_pandas = True
except:
_have_old_pandas = True
except:
# No Pandas, but that's okay, we'll skip those tests
pass
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
except ImportError as e:
from pyspark.util import _exception_message
# If Pandas version requirement is not satisfied, skip related tests.
_pandas_requirement_message = _exception_message(e)

_pyarrow_requirement_message = None
try:
from pyspark.sql.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
except ImportError as e:
from pyspark.util import _exception_message
# If Arrow version requirement is not satisfied, skip related tests.
_pyarrow_requirement_message = _exception_message(e)

_have_pandas = _pandas_requirement_message is None
_have_pyarrow = _pyarrow_requirement_message is None

from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
@@ -75,15 +82,6 @@
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


_have_arrow = False
try:
import pyarrow
_have_arrow = True
except:
# No Arrow, but that's okay, we'll skip those tests
pass


class UTCOffsetTimezone(datetime.tzinfo):
"""
Specifies timezone in UTC offset
@@ -2794,7 +2792,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"):

def _to_pandas(self):
from datetime import datetime, date
import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
.add("c", BooleanType()).add("d", FloatType())\
.add("dt", DateType()).add("ts", TimestampType())
@@ -2807,7 +2804,7 @@ def _to_pandas(self):
df = self.spark.createDataFrame(data, schema)
return df.toPandas()

@unittest.skipIf(not _have_pandas, "Pandas not installed")
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_to_pandas(self):
import numpy as np
pdf = self._to_pandas()
@@ -2819,13 +2816,13 @@ def test_to_pandas(self):
self.assertEquals(types[4], np.object) # datetime.date
self.assertEquals(types[5], 'datetime64[ns]')

@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
def test_to_pandas_old(self):
@unittest.skipIf(_have_pandas, "Required Pandas was found.")
def test_to_pandas_required_pandas_not_found(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
self._to_pandas()

@unittest.skipIf(not _have_pandas, "Pandas not installed")
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_to_pandas_avoid_astype(self):
import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
@@ -2843,7 +2840,7 @@ def test_create_dataframe_from_array_of_long(self):
df = self.spark.createDataFrame(data)
self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))

@unittest.skipIf(not _have_pandas, "Pandas not installed")
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_create_dataframe_from_pandas_with_timestamp(self):
import pandas as pd
from datetime import datetime
@@ -2858,14 +2855,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
def test_create_dataframe_from_old_pandas(self):
import pandas as pd
from datetime import datetime
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
"d": [pd.Timestamp.now().date()]})
@unittest.skipIf(_have_pandas, "Required Pandas was found.")
def test_create_dataframe_required_pandas_not_found(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
with self.assertRaisesRegexp(
ImportError,
'(Pandas >= .* must be installed|No module named pandas)'):
import pandas as pd
from datetime import datetime
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
"d": [pd.Timestamp.now().date()]})
self.spark.createDataFrame(pdf)


@@ -3383,7 +3382,9 @@ def __init__(self, **kwargs):
_make_type_verifier(data_type, nullable=False)(obj)


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class ArrowTests(ReusedSQLTestCase):

@classmethod
@@ -3641,7 +3642,9 @@ def test_createDataFrame_with_int_col_names(self):
self.assertEqual(pdf_col_names, df_arrow.columns)


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):
def test_pandas_udf_basic(self):
from pyspark.rdd import PythonEvalType
@@ -3765,7 +3768,9 @@ def foo(k, v):
return k


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class ScalarPandasUDFTests(ReusedSQLTestCase):

@classmethod
@@ -4278,7 +4283,9 @@ def test_register_vectorized_udf_basic(self):
self.assertEquals(expected.collect(), res2.collect())


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class GroupedMapPandasUDFTests(ReusedSQLTestCase):

@property
@@ -4447,7 +4454,9 @@ def test_unsupported_types(self):
df.groupby('id').apply(f).collect()


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class GroupedAggPandasUDFTests(ReusedSQLTestCase):

@property
30 changes: 22 additions & 8 deletions python/pyspark/sql/utils.py
@@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr):
def require_minimum_pandas_version():
""" Raise ImportError if minimum version of Pandas is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
minimum_pandas_version = "0.19.2"

from distutils.version import LooseVersion
import pandas
if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
"however, your version was %s." % pandas.__version__)
try:
import pandas
except ImportError:
raise ImportError("Pandas >= %s must be installed; however, "
"it was not found." % minimum_pandas_version)
if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
raise ImportError("Pandas >= %s must be installed; however, "
"your version was %s." % (minimum_pandas_version, pandas.__version__))


def require_minimum_pyarrow_version():
""" Raise ImportError if minimum version of pyarrow is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
minimum_pyarrow_version = "0.8.0"

from distutils.version import LooseVersion
import pyarrow
if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
"however, your version was %s." % pyarrow.__version__)
try:
import pyarrow
except ImportError:
raise ImportError("PyArrow >= %s must be installed; however, "
"it was not found." % minimum_pyarrow_version)
if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
raise ImportError("PyArrow >= %s must be installed; however, "
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
10 changes: 9 additions & 1 deletion python/setup.py
@@ -100,6 +100,11 @@ def _supports_symlinks():
file=sys.stderr)
exit(-1)

# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and
# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml.
_minimum_pandas_version = "0.19.2"
_minimum_pyarrow_version = "0.8.0"

try:
# We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts
# find it where expected. The rest of the files aren't copied because they are accessed
@@ -201,7 +206,10 @@ def _supports_symlinks():
extras_require={
'ml': ['numpy>=1.7'],
'mllib': ['numpy>=1.7'],
'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0']
'sql': [
'pandas>=%s' % _minimum_pandas_version,
'pyarrow>=%s' % _minimum_pyarrow_version,
]
},
classifiers=[
'Development Status :: 5 - Production/Stable',
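
With the extras declared this way, installing the `sql` extra pulls in the pinned minimums (a usage sketch; exact resolution depends on the pip environment):

```
pip install pyspark[sql]  # installs pandas>=0.19.2 and pyarrow>=0.8.0
```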
