Skip to content

Commit

Permalink
pre-commit post spark rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Aug 22, 2023
1 parent a1ac01d commit 5696e81
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 21 deletions.
28 changes: 8 additions & 20 deletions graph_adapter_tests/h_spark/test_h_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import pyspark.pandas as ps
import pytest
from pyspark import Row
from pyspark.sql import Column, DataFrame, SparkSession
from pyspark.sql import types
from pyspark.sql import Column, DataFrame, SparkSession, types
from pyspark.sql.functions import column

from hamilton import base, driver, htypes, node
from hamilton.experimental import h_spark

from .resources import example_module, smoke_screen_module
from .resources.spark import (
basic_spark_dag,
Expand Down Expand Up @@ -304,10 +304,7 @@ def test_python_to_spark_type_invalid(invalid_python_type):
],
)
def test_get_spark_type_basic_types(return_type, expected_spark_type):
    """Check that basic Python types map to the expected Spark SQL types.

    Parametrized (decorator above) over pairs of a Python builtin type
    (``int``, ``float``, ``bool``, ``str``, ``bytes``) and the Spark
    ``types.DataType`` instance it should translate to.
    """
    # Diff-merge artifact removed: the old multi-line assert duplicated
    # this exact check; a single assertion is sufficient.
    assert h_spark.get_spark_type(return_type) == expected_spark_type


# 2. Lists of basic Python types
Expand All @@ -322,14 +319,9 @@ def test_get_spark_type_basic_types(return_type, expected_spark_type):
(bytes, types.ArrayType(types.BinaryType())),
],
)
def test_get_spark_type_list_types(return_type, expected_spark_type):
    """Check that ``list[T]`` of a basic Python type maps to an ArrayType.

    Parametrized (decorator above) over pairs of an element type and the
    expected ``types.ArrayType(...)`` wrapping the corresponding Spark type.
    """
    # Wrap the parametrized element type in the builtin-generic list form
    # (requires Python 3.9+); `type: ignore` kept for older type checkers.
    return_type = list[return_type]  # type: ignore
    assert h_spark.get_spark_type(return_type) == expected_spark_type


# 3. Numpy types (assuming you have a numpy_to_spark_type function that handles these)
Expand All @@ -341,13 +333,8 @@ def test_get_spark_type_list_types(
(np.bool_, types.BooleanType()),
],
)
def test_get_spark_type_numpy_types(return_type, expected_spark_type):
    """Check that numpy scalar types map to the expected Spark SQL types.

    Parametrized (decorator above) over pairs of a numpy type
    (e.g. ``np.int64``, ``np.float64``, ``np.bool_``) and the Spark
    ``types.DataType`` instance it should translate to.
    """
    # Diff-merge artifact removed: the span contained two definitions of
    # this function (old multi-line form shadowed by the new one); a
    # single definition with one assertion is the intended code.
    assert h_spark.get_spark_type(return_type) == expected_spark_type


# 4. Unsupported types
Expand Down Expand Up @@ -379,6 +366,7 @@ def dummyfunc(x: int) -> int:

return dummyfunc


def test_base_spark_executor_end_to_end(spark_session):
# TODO -- make this simpler to call, and not require all these constructs
dr = (
Expand Down
1 change: 0 additions & 1 deletion hamilton/experimental/h_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def python_to_spark_type(python_type: Type[Union[int, float, bool, str, bytes]])
_list = (list[int], list[float], list[bool], list[str], list[bytes])



def get_spark_type(return_type: Any) -> types.DataType:
if return_type in (int, float, bool, str, bytes):
return python_to_spark_type(return_type)
Expand Down

0 comments on commit 5696e81

Please sign in to comment.