[SPARK-26559][ML][PYSPARK] ML image can't work with numpy versions prior to 1.9

## What changes were proposed in this pull request?

Due to an [API change](https://github.com/numpy/numpy/pull/4257/files#diff-c39521d89f7e61d6c0c445d93b62f7dc) in 1.9, PySpark image doesn't work with numpy versions prior to 1.9.
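The incompatibility is easy to reproduce on a bare array. The sketch below is illustrative only (the `arr`/`raw` names are not from the patch); it shows that `tobytes` only appears in NumPy 1.9, with `tostring` being the older equivalent:

```python
# Illustrative only: `ndarray.tobytes` was introduced in NumPy 1.9;
# earlier releases expose the same bytes via `tostring`.
import numpy as np

arr = np.array([1, 2, 3], dtype=np.uint8)
if hasattr(arr, 'tobytes'):       # NumPy >= 1.9
    raw = arr.tobytes()
else:                             # NumPy < 1.9
    raw = arr.tostring()
print(raw)  # b'\x01\x02\x03'
```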

When running the image tests with a numpy version prior to 1.9, we see this error:
```
test_read_images (pyspark.ml.tests.test_image.ImageReaderTest) ... ERROR
test_read_images_multiple_times (pyspark.ml.tests.test_image.ImageReaderTest2) ... ok

======================================================================
ERROR: test_read_images (pyspark.ml.tests.test_image.ImageReaderTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/viirya/docker_tmp/repos/spark-1/python/pyspark/ml/tests/test_image.py", line 36, in test_read_images
    self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row)
  File "/Users/viirya/docker_tmp/repos/spark-1/python/pyspark/ml/image.py", line 193, in toImage
    data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
AttributeError: 'numpy.ndarray' object has no attribute 'tobytes'

----------------------------------------------------------------------
Ran 2 tests in 29.040s

FAILED (errors=1)
```

## How was this patch tested?

Manually tested with numpy versions prior to and after 1.9.

Closes #23484 from viirya/fix-pyspark-image.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
viirya authored and HyukjinKwon committed Jan 7, 2019
1 parent 468d25e commit a927c76
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion in python/pyspark/ml/image.py

```diff
@@ -28,6 +28,7 @@
 import warnings
 
 import numpy as np
+from distutils.version import LooseVersion
 
 from pyspark import SparkContext
 from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
@@ -190,7 +191,11 @@ def toImage(self, array, origin=""):
         # Running `bytearray(numpy.array([1]))` fails in specific Python versions
         # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
         # Here, it avoids it by converting it to bytes.
-        data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
+        if LooseVersion(np.__version__) >= LooseVersion('1.9'):
+            data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
+        else:
+            # NumPy prior to 1.9 doesn't have the `tobytes` method.
+            data = bytearray(array.astype(dtype=np.uint8).ravel())
 
         # Creating new Row with _create_row(), because Row(name = value, ... )
         # orders fields by name, which conflicts with expected schema order
```
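As a quick sanity check of the version gate (illustrative only, not taken from the commit), `LooseVersion` compares release strings component by component, so the two branches above resolve as expected:

```python
# Illustrative only: LooseVersion splits '1.8.2' into [1, 8, 2] and '1.9'
# into [1, 9], then compares the components numerically.
from distutils.version import LooseVersion

print(LooseVersion('1.8.2') >= LooseVersion('1.9'))   # False -> fallback branch
print(LooseVersion('1.13.3') >= LooseVersion('1.9'))  # True  -> tobytes branch
```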
