diff --git a/graph_adapter_tests/h_spark/test_h_spark.py b/graph_adapter_tests/h_spark/test_h_spark.py
index 3dec1706c..9ec4a0485 100644
--- a/graph_adapter_tests/h_spark/test_h_spark.py
+++ b/graph_adapter_tests/h_spark/test_h_spark.py
@@ -5,12 +5,12 @@
 import pyspark.pandas as ps
 import pytest
 from pyspark import Row
-from pyspark.sql import Column, DataFrame, SparkSession
-from pyspark.sql import types
+from pyspark.sql import Column, DataFrame, SparkSession, types
 from pyspark.sql.functions import column
 
 from hamilton import base, driver, htypes, node
 from hamilton.experimental import h_spark
+
 from .resources import example_module, smoke_screen_module
 from .resources.spark import (
     basic_spark_dag,
@@ -304,10 +304,7 @@ def test_python_to_spark_type_invalid(invalid_python_type):
     ],
 )
 def test_get_spark_type_basic_types(return_type, expected_spark_type):
-    assert (
-        h_spark.get_spark_type(return_type)
-        == expected_spark_type
-    )
+    assert h_spark.get_spark_type(return_type) == expected_spark_type
 
 
 # 2. Lists of basic Python types
@@ -322,14 +319,9 @@ def test_get_spark_type_basic_types(return_type, expected_spark_type):
         (bytes, types.ArrayType(types.BinaryType())),
     ],
 )
-def test_get_spark_type_list_types(
-    return_type, expected_spark_type
-):
+def test_get_spark_type_list_types(return_type, expected_spark_type):
     return_type = list[return_type]  # type: ignore
-    assert (
-        h_spark.get_spark_type(return_type)
-        == expected_spark_type
-    )
+    assert h_spark.get_spark_type(return_type) == expected_spark_type
 
 
 # 3. Numpy types (assuming you have a numpy_to_spark_type function that handles these)
@@ -341,13 +333,8 @@ def test_get_spark_type_list_types(
         (np.bool_, types.BooleanType()),
     ],
 )
-def test_get_spark_type_numpy_types(
-    return_type, expected_spark_type
-):
-    assert (
-        h_spark.get_spark_type(return_type)
-        == expected_spark_type
-    )
+def test_get_spark_type_numpy_types(return_type, expected_spark_type):
+    assert h_spark.get_spark_type(return_type) == expected_spark_type
 
 
 # 4. Unsupported types
@@ -379,6 +366,7 @@ def dummyfunc(x: int) -> int:
     return dummyfunc
 
+
 def test_base_spark_executor_end_to_end(spark_session):
     # TODO -- make this simpler to call, and not require all these constructs
     dr = (
diff --git a/hamilton/experimental/h_spark.py b/hamilton/experimental/h_spark.py
index c98723746..c5ada3073 100644
--- a/hamilton/experimental/h_spark.py
+++ b/hamilton/experimental/h_spark.py
@@ -215,11 +215,10 @@ def python_to_spark_type(python_type: Type[Union[int, float, bool, str, bytes]])
 
 _list = (list[int], list[float], list[bool], list[str], list[bytes])
 
-
 def get_spark_type(return_type: Any) -> types.DataType:
     if return_type in (int, float, bool, str, bytes):
         return python_to_spark_type(return_type)
-    elif return_type in (list[int], list[float], list[bool], list[str], list[bytes]):
+    elif return_type in _list:
         return types.ArrayType(python_to_spark_type(return_type.__args__[0]))
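
A quick sanity check of the `h_spark.py` refactor above: the new `elif return_type in _list:` branch is behavior-preserving, since `_list` is the same tuple of builtin-generic list types that was previously inlined. A minimal sketch of what the parametrized tests exercise (assuming Python 3.9+ for builtin generics and an importable `hamilton.experimental.h_spark`; the type pairs are taken from the parametrizations visible in the diff):

```python
import numpy as np
from pyspark.sql import types

from hamilton.experimental import h_spark

# list[T] hits the `_list` branch: ArrayType of the element's Spark type
# (pair taken from the test_get_spark_type_list_types parametrization).
assert h_spark.get_spark_type(list[bytes]) == types.ArrayType(types.BinaryType())

# Numpy scalar types map through their own branch (pair taken from the
# test_get_spark_type_numpy_types parametrization).
assert h_spark.get_spark_type(np.bool_) == types.BooleanType()
```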