Adds unit tests for some missing functions in h_spark
So that we get some coverage going in case we accidentally break something.

Makes sure that we handle Python <3.9 appropriately too.
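
For context on the <3.9 handling: builtin generics like list[int] only became subscriptable in Python 3.9 (PEP 585), so older versions need typing.List instead — which is what the _list tuple in the diff below guards. A minimal illustration of the difference (a sketch, separate from the commit's code):

import sys
from typing import List

# typing.List works on all supported versions and exposes the element
# type through __args__, just like the 3.9+ builtin generics do:
assert List[int].__args__ == (int,)

if sys.version_info >= (3, 9):
    # PEP 585: builtin containers are subscriptable from 3.9 onward.
    assert list[int].__args__ == (int,)
else:
    try:
        list[int]  # on <3.9 this raises: 'type' object is not subscriptable
    except TypeError:
        pass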
skrawcz committed Aug 18, 2023
1 parent eae3edd commit 1807e6d
Showing 2 changed files with 126 additions and 4 deletions.
117 changes: 116 additions & 1 deletion graph_adapter_tests/h_spark/test_h_spark.py
@@ -1,8 +1,11 @@
+import sys
+
+import numpy as np
 import pandas as pd
 import pyspark.pandas as ps
 import pytest
 from pyspark import Row
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, types
 from pyspark.sql.functions import column
 
 from hamilton import base, driver, htypes, node
@@ -235,3 +238,115 @@ def test_smoke_screen_udf_graph_adatper(spark_session):
         Row(a=2, b=5, base_func=7, base_func2=11, base_func3=9),
         Row(a=3, b=6, base_func=9, base_func2=13, base_func3=9),
     ]
+
+
+# Test cases for python_to_spark_type function
+@pytest.mark.parametrize(
+    "python_type,expected_spark_type",
+    [
+        (int, types.IntegerType()),
+        (float, types.FloatType()),
+        (bool, types.BooleanType()),
+        (str, types.StringType()),
+        (bytes, types.BinaryType()),
+    ],
+)
+def test_python_to_spark_type_valid(python_type, expected_spark_type):
+    assert h_spark.python_to_spark_type(python_type) == expected_spark_type
+
+
+@pytest.mark.parametrize("invalid_python_type", [list, dict, tuple, set])
+def test_python_to_spark_type_invalid(invalid_python_type):
+    with pytest.raises(ValueError, match=f"Unsupported Python type: {invalid_python_type}"):
+        h_spark.python_to_spark_type(invalid_python_type)
+
+
+# Test cases for get_spark_type function
+# 1. Basic Python types
+@pytest.mark.parametrize(
+    "return_type,expected_spark_type",
+    [
+        (int, types.IntegerType()),
+        (float, types.FloatType()),
+        (bool, types.BooleanType()),
+        (str, types.StringType()),
+        (bytes, types.BinaryType()),
+    ],
+)
+def test_get_spark_type_basic_types(
+    dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type
+):
+    assert (
+        h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type)
+        == expected_spark_type
+    )
+
+
+# 2. Lists of basic Python types
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher")
+@pytest.mark.parametrize(
+    "return_type,expected_spark_type",
+    [
+        (int, types.ArrayType(types.IntegerType())),
+        (float, types.ArrayType(types.FloatType())),
+        (bool, types.ArrayType(types.BooleanType())),
+        (str, types.ArrayType(types.StringType())),
+        (bytes, types.ArrayType(types.BinaryType())),
+    ],
+)
+def test_get_spark_type_list_types(
+    dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type
+):
+    return_type = list[return_type]  # type: ignore
+    assert (
+        h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type)
+        == expected_spark_type
+    )
+
+
+# 3. Numpy types (assuming you have a numpy_to_spark_type function that handles these)
+@pytest.mark.parametrize(
+    "return_type,expected_spark_type",
+    [
+        (np.int64, types.IntegerType()),
+        (np.float64, types.FloatType()),
+        (np.bool_, types.BooleanType()),
+    ],
+)
+def test_get_spark_type_numpy_types(
+    dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type
+):
+    assert (
+        h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type)
+        == expected_spark_type
+    )
+
+
+# 4. Unsupported types
+@pytest.mark.parametrize(
+    "unsupported_return_type", [dict, set, tuple]  # Add other unsupported types as needed
+)
+def test_get_spark_type_unsupported(dummy_kwargs, dummy_df, dummy_udf, unsupported_return_type):
+    with pytest.raises(
+        ValueError, match=f"Currently unsupported return type {unsupported_return_type}."
+    ):
+        h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, unsupported_return_type)
+
+
+# Dummy values for the tests
+@pytest.fixture
+def dummy_kwargs():
+    return {}
+
+
+@pytest.fixture
+def dummy_df(spark_session):
+    return spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
+
+
+@pytest.fixture
+def dummy_udf():
+    def dummyfunc(x: int) -> int:
+        return x
+
+    return dummyfunc
13 changes: 10 additions & 3 deletions hamilton/experimental/h_spark.py
@@ -1,7 +1,8 @@
 import functools
 import inspect
 import logging
-from typing import Any, Callable, Dict, Set, Tuple, Type, Union
+import sys
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
 
 import numpy as np
 import pandas as pd
@@ -178,7 +179,7 @@ def numpy_to_spark_type(numpy_type: Type) -> types.DataType:
         raise ValueError("Unsupported NumPy type: " + str(numpy_type))
 
 
-def python_to_spark_type(python_type: Union[int, float, bool, str, bytes]) -> types.DataType:
+def python_to_spark_type(python_type: Type[Union[int, float, bool, str, bytes]]) -> types.DataType:
     """Function to convert a Python type to a Spark type.
 
     :param python_type: the Python type to convert.
@@ -199,12 +200,18 @@ def python_to_spark_type(python_type: Union[int, float, bool, str, bytes]) -> types.DataType:
         raise ValueError("Unsupported Python type: " + str(python_type))
 
 
+if sys.version_info < (3, 9):
+    _list = (List[int], List[float], List[bool], List[str], List[bytes])
+else:
+    _list = (list[int], list[float], list[bool], list[str], list[bytes])
+
+
 def get_spark_type(
     actual_kwargs: dict, df: DataFrame, hamilton_udf: Callable, return_type: Any
 ) -> types.DataType:
     if return_type in (int, float, bool, str, bytes):
         return python_to_spark_type(return_type)
-    elif return_type in (list[int], list[float], list[bool], list[str], list[bytes]):
+    elif return_type in _list:
         return types.ArrayType(python_to_spark_type(return_type.__args__[0]))
     elif hasattr(return_type, "__module__") and getattr(return_type, "__module__") == "numpy":
         return numpy_to_spark_type(return_type)

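As a quick sanity check of the mapping introduced above (a sketch, not part of the commit — per the diff, get_spark_type does not touch its kwargs/df/udf arguments on these branches, so None placeholders suffice):

import sys

import numpy as np
from pyspark.sql import types

from hamilton.experimental import h_spark

assert h_spark.python_to_spark_type(str) == types.StringType()
assert h_spark.get_spark_type({}, None, None, int) == types.IntegerType()
assert h_spark.get_spark_type({}, None, None, np.float64) == types.FloatType()
if sys.version_info >= (3, 9):
    # list[...] only subscripts on 3.9+, matching the skipif in the tests.
    assert h_spark.get_spark_type({}, None, None, list[bool]) == types.ArrayType(
        types.BooleanType()
    )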