diff --git a/.circleci/config.yml b/.circleci/config.yml index 46f6b4e07..f573dbbcc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -55,8 +55,16 @@ workflows: python-version: '3.9' task: ray - test: - name: spark-py38 - python-version: '3.8' + name: spark-py39 + python-version: '3.9' + task: pyspark + - test: + name: spark-py310 + python-version: '3.10' + task: pyspark + - test: + name: spark-py311 + python-version: '3.11' task: pyspark - test: name: integrations-py37 diff --git a/docs/how-tos/scale-up.rst b/docs/how-tos/scale-up.rst index 6f237a9f4..ebf26fab8 100644 --- a/docs/how-tos/scale-up.rst +++ b/docs/how-tos/scale-up.rst @@ -10,4 +10,4 @@ on larger, distributed datasets (pandas on spark, pyspark map UDFs). 1. Integrating hamilton with `pandas on spark `_. 2. Integrating hamilton with `ray `_. 3. Integrating hamilton with `dask `_. -4. Integrating hamilton using `pyspark map UDFs `__. +4. Integrating hamilton with `pyspark `_. diff --git a/examples/spark/README.md b/examples/spark/README.md index da5684881..df7a588a5 100644 --- a/examples/spark/README.md +++ b/examples/spark/README.md @@ -1,4 +1,11 @@ + # Scaling Hamilton on Spark +## Pyspark + +If you're using pyspark, Hamilton allows for natural manipulation of pyspark dataframes, +with some special constructs for managing DAGs of UDFs. + +See the example in `pyspark` to learn more. ## Pandas If you're using Pandas, Hamilton scales by using Koalas on Spark. diff --git a/examples/spark/pyspark/README.md b/examples/spark/pyspark/README.md new file mode 100644 index 000000000..483905936 --- /dev/null +++ b/examples/spark/pyspark/README.md @@ -0,0 +1,324 @@ +# Hamilton and Pyspark + +**TL;DR** Hamilton now supports full pyspark integration. This enables you to write a DAG of transformations as +python UDFs, pandas UDFs, or pyspark transformations and apply them to a central dataframe, using the +[Hamilton](https://github.com/dagworks-inc/hamilton) paradigm. Functions written this way improve maintaibility, modularity, +readability, and clarity of data lineage in spark ETLs. + +
+
+*A spark pipeline representing multiple joins (in blue), a set of map operations (in green) and a set of join/filters (in yellow). This uses Hamilton’s visualization features (with a little extra annotation). See [TPC-H query 8](../tpc-h/query_8.py) for motivation.*
+
+## Apache Spark
+[Apache Spark](https://spark.apache.org/) (and its python API, [pyspark](https://spark.apache.org/docs/latest/api/python/index.html)) is an open-source library for building out highly scalable data transformations.
+At its core is the notion of the RDD (resilient distributed dataset), which represents a lazily evaluated,
+partitioned, in-memory dataset that stores the information needed to recreate the data
+if any of the servers computing it fail. The pyspark library gives data practitioners
+a dataframe-centric API to interact with this in python, enabling them to specify computation
+and scale up to the resources they have available. Since its introduction in 2014, spark has taken
+off and is now the de facto way to perform computations on large (multi-GB to multi-TB) datasets.
+
+## Limitations of Spark for Complex Pipelines
+
+Like any ETL, spark pipelines can be difficult to maintain and manage,
+and often devolve into spaghetti code over time.
+Specifically, we've observed the following problems with pyspark pipelines:
+
+1. _They rarely get broken up into modular and reusable components._
+2. _They commonly contain "implicit" dependencies._ Even when you do break them into functions, it is difficult to specify which columns the transformed dataframes depend on, and how that changes throughout your workflow.
+3. _They are difficult to configure in a readable manner._ A monolithic spark script likely has a few different shapes/parameters, and naturally becomes littered with poorly documented if/else statements.
+4. _They are not easy to unit test._ While specific UDFs can be tested, spark transformations are tough to test in a modular fashion.
+5. _They are notoriously tricky to debug._ Large pipelines of spark transformations (much like SQL transformations) will often have errors that cascade upwards, and pinpointing the source of these can be quite a challenge.
+
+
+# Hamilton
+As this is a README inside the Hamilton repository, we assume some basic familiarity. That said, here's a quick primer:
+
+Hamilton is an open-source Python framework for writing data transformations.
+One writes Python functions in a declarative style, which Hamilton parses into nodes in
+a graph based on their names, arguments and type annotations. The simple rule is akin to that of pytest fixtures --
+the name of a parameter points to another node (function) in the graph, and the name of the function defines a referenceable node.
+You can request specific outputs, and Hamilton will execute the required nodes (specified by your functions) to produce them.
+
+You can try hamilton out in your browser at [tryhamilton.dev](https://tryhamilton.dev).
+
+# Integration
+
+Breaking your pipeline into Hamilton functions with pyspark dataframes as inputs and outputs gets you most of
+the way towards more modular/documented code.
+That said, it falls flat in a critical area – column-level lineage/transformation
+simplicity. For complex series of map operations, spark represents all transformations
+on a single dataframe in a linear chain by repeatedly calling `withColumn`/`select` to create columns, as in the sketch below.
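+
+Purely as an illustration (this snippet is not part of the example code in this directory; the column names just mirror `map_transforms.py`), a hand-rolled pipeline in that style looks roughly like this:
+
+```python
+import pyspark.sql as ps
+import pyspark.sql.functions as F
+
+spark = ps.SparkSession.builder.master("local[1]").getOrCreate()
+df = spark.createDataFrame([(10, 1), (20, 5), (40, 100)], ["spend", "signups"])
+# Precompute scalar statistics used by later columns.
+stats = df.agg(F.mean("spend").alias("mean"), F.stddev("spend").alias("std")).collect()[0]
+# Each new column is tacked onto the same dataframe, one withColumn call at a time.
+df = (
+    df.withColumn("spend_per_signup", F.col("spend") / F.col("signups"))
+    .withColumn("spend_zero_mean", F.col("spend") - F.lit(stats["mean"]))
+    .withColumn("spend_zero_mean_unit_variance", F.col("spend_zero_mean") / F.lit(stats["std"]))
+)
+```
+
+Every column definition is appended to one dataframe, so the dependency structure between columns lives only in the order of the `withColumn` calls -- which is exactly what the integration described below makes explicit.
+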
+For dataframe APIs that manage indices, hamilton improves this experience by encouraging the user to
+pull apart column-level transformations then join later. With columns that share cardinality, this is generally an efficient approach.
+
+Spark, however, has no notion of indices. Data is partitioned across a cluster, and once a set of columns is selected it has the potential to be reshuffled.
+Thus, the two options one previously had for integrating with pyspark both have disadvantages:
+
+1. Extracting into columns then joining is prohibitively expensive and taxing on the spark optimizer (which we have not found to be smart enough to detect this pattern)
+2. Running pure DataFrame transformations does not afford the expressiveness that Hamilton provides.
+
+The idea is to break your code into components. These components take one of two shapes:
+
+1. Run linearly (e.g. cardinality non-preserving operations: aggregations, filters, joins, etc.)
+2. Form a DAG of column-level operations (for cardinality-preserving operations)
+
+For the first case, we just use the pyspark dataframe API. You define functions that, when put
+through Hamilton, act as a pipe -- see the first example below.
+
+Hamilton `1.27.0` introduces a new API to give the user the best of both worlds. You can now express column-level
+operations in a DAG on the same dataframe, as part of a multi-step process.
+
+With the new `@with_columns` decorator, you break your pipeline into two classes of steps:
+
+### Joins/Aggregations/Filters
+
+We simply write functions that take in dataframes and return dataframes.
+```python
+import pyspark.sql as ps
+
+def raw_data_1() -> ps.DataFrame:
+    """Loads up data from an external source"""
+
+def raw_data_2() -> ps.DataFrame:
+    """Loads up data from an external source"""
+
+def all_initial_data(raw_data_1: ps.DataFrame, raw_data_2: ps.DataFrame) -> ps.DataFrame:
+    """Combines the two dataframes"""
+    return _join(raw_data_1, raw_data_2)
+
+def raw_data_3() -> ps.DataFrame:
+    """Loads up data from an external source"""
+```
+
+### Columnar Operations
+
+For the next case, we define transformations that are columnar/map-oriented in nature.
+These are UDFs (either pandas or python), or functions of pyspark constructs, that get applied
+to the upstream dataframe in a specific order:
+
+```python
+import pandas as pd
+
+#map_transforms.py
+
+def column_3(column_1_from_dataframe: pd.Series) -> pd.Series:
+    return _some_transform(column_1_from_dataframe)
+
+def column_4(column_2_from_dataframe: pd.Series) -> pd.Series:
+    return _some_other_transform(column_2_from_dataframe)
+
+def column_5(column_3: pd.Series, column_4: pd.Series) -> pd.Series:
+    return _yet_another_transform(column_3, column_4)
+```
+
+Finally, we combine them with a call to `with_columns`:
+
+```python
+from hamilton.experimental.h_spark import with_columns
+import pyspark.sql as ps
+import map_transforms # file defined above
+
+@with_columns(
+    map_transforms, # Load all the functions we defined above
+    columns_to_pass=[
+        "column_1_from_dataframe",
+        "column_2_from_dataframe",
+        "column_3_from_dataframe"], # use these from the initial dataframe
+)
+def final_result(all_initial_data: ps.DataFrame, raw_data_3: ps.DataFrame) -> ps.DataFrame:
+    """Gives the final result. This decorator will apply the transformations in order.
+    Then, the final_result function is called, with the result of the transformations passed in."""
+    return _join(all_initial_data, raw_data_3)
+```
+Contained within the `load_from` functions/modules is a set of transformations that specifies a DAG.
+These transformations can take multiple forms – they can use vanilla pyspark operations, pandas UDFs,
+or standard python UDFs. See documentation for specific examples.
+
+The easiest way to think about this is that the `with_columns` decorator “linearizes” the DAG.
+It turns a DAG of hamilton functions into a linear chain, repeatedly appending those columns to the initial dataframe.
+![](illustration.png)
+*The natural DAG of steps in three separate configurations -- hamilton/pandas, pure pyspark, and pyspark + hamilton*
+
+`with_columns` takes in the following parameters (see the docstring for more info):
+1. `load_from` -- a list of functions/modules to find the functions to load the DAG from, similar to `@subdag`
+2. `columns_to_pass` -- not compatible with `pass_dataframe_as`. Dependencies specified from the initial dataframe,
+injected in. Note that you must use exactly one of this or `pass_dataframe_as`
+3. `pass_dataframe_as` -- the name of the parameter to inject the initial dataframe into the subdag.
+If this is provided, this must be the only pyspark dataframe dependency in the subdag that is not also another
+node (column) in the subdag.
+4. `select` -- a list of columns to select from the UDF group. If not specified all will be selected.
+5. `namespace` -- the namespace of the nodes generated by this -- will default to the name of the decorated function.
+
+Note that the dependency that forms the core dataframe will always be the first parameter to the function. Therefore, the first parameter
+must be a pyspark dataframe and share the name of an upstream node that returns a pyspark dataframe.
+
+You have two options for presenting the initial dataframe/how to read it. Each corresponds to a `with_columns` parameter. You can use:
+1. `columns_to_pass` to constrain the columns that must exist in the initial dataframe, which you refer to in your functions. In the example above, the functions can refer to the three columns `column_1_from_dataframe`, `column_2_from_dataframe`, and `column_3_from_dataframe`, but those cannot also be the names of nodes defined by the subdag.
+2. `pass_dataframe_as` to pass the dataframe you're transforming in as a specific parameter name to the subdag. This allows you to handle the extraction -- use this if you want to redefine columns in the dataframe/preserve the same names. For example:
+
+```python
+import pandas as pd, pyspark.sql as ps
+
+#map_transforms.py
+
+def column_1_from_dataframe(input_dataframe: ps.DataFrame) -> ps.Column:
+    return input_dataframe.column_1_from_dataframe
+
+def column_2_from_dataframe(input_dataframe: ps.DataFrame) -> ps.Column:
+    return input_dataframe.column_2_from_dataframe
+
+def column_3(column_1_from_dataframe: pd.Series) -> pd.Series:
+    return _some_transform(column_1_from_dataframe)
+
+def column_4(column_2_from_dataframe: pd.Series) -> pd.Series:
+    return _some_other_transform(column_2_from_dataframe)
+
+def column_5(column_3: pd.Series, column_4: pd.Series) -> pd.Series:
+    return _yet_another_transform(column_3, column_4)
+```
+
+```python
+from hamilton.experimental.h_spark import with_columns
+import pyspark.sql as ps
+import map_transforms # file defined above
+
+@with_columns(
+    map_transforms, # Load all the functions we defined above
+    pass_dataframe_as="input_dataframe", # the upstream dataframe, referred to by downstream nodes, will have this parameter name
+)
+def final_result(all_initial_data: ps.DataFrame, raw_data_3: ps.DataFrame) -> ps.DataFrame:
+    """Gives the final result. This decorator will apply the transformations in order.
+    Then, the final_result function is called, with the result of the transformations passed in."""
+    return _join(all_initial_data, raw_data_3)
+```
+
+With approach (2), the functions that read directly from the injected dataframe must take in a pyspark dataframe and return a pyspark dataframe or column.
+If you want to stay in pandas entirely for the `with_columns` group, you should use approach (1).
+
+### Flavors of UDFs
+
+There are four flavors of transforms supported that compose the group of DAG transformations:
+
+#### Pandas -> Pandas UDFs
+These are functions of series:
+
+```python
+import pandas as pd
+
+from hamilton import htypes
+
+def foo(bar: pd.Series, baz: pd.Series) -> htypes.column[pd.Series, int]:
+    return bar + 1
+```
+
+The rules are the same as vanilla hamilton -- the parameter name determines the upstream dependencies,
+and the function name determines the output column name.
+
+Note that, due to the type-specification requirements of pyspark, these have to return a "typed" (`Annotated[]`) series, specified by `htypes.column`. These are adapted
+to form pyspark-friendly [pandas UDFs](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.pandas_udf.html).
+
+#### Python primitives -> Python Primitives UDFs
+
+These are functions of python primitives:
+
+```python
+def foo(bar: int, baz: int) -> int:
+    return bar + 1
+```
+
+These are adapted to standard [pyspark UDFs](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.udf.html).
+
+#### Pyspark Dataframe -> Pyspark Columns
+
+These are functions that take in a single pyspark dataframe and output a pyspark column.
+
+```python
+def foo(bar: ps.DataFrame) -> ps.Column:
+    return bar["bar"] + 1
+```
+
+Note that these have two forms:
+1. The dataframe specifies the name of the upstream column -- then you just access the column and return a manipulation
+2. The dataframe contains more than one column, in which case you need the `@h_spark.require_columns(...)` decorator, to specify which columns you want to use.
+
+```python
+import pyspark.sql as ps
+
+from hamilton.experimental import h_spark
+
+
+@h_spark.require_columns("bar", "baz")
+def foo(bar_baz: ps.DataFrame) -> ps.Column:
+    return bar_baz["bar"] + 1
+```
+
+In this case we are only allowed a single dataframe dependency, and the parameter name does not matter.
+The columns specified are injected into the dataframe, allowing you to depend on multiple upstream columns.
+
+#### pyspark dataframe -> pyspark dataframe
+
+This is the ultimate power-user case, where you can manipulate the dataframe in any way you want.
+Note that this and the column flavor are an _out_, meaning that it's a way to jump back to the pyspark world and not have to break up
+your map functions for a windowed aggregation. You can easily shoot yourself in the foot here. This should only be used if
+you strongly feel the need to inject a map-like (index-preserving, but not row-wise) operation into the DAG,
+and the df -> column flavor is not sufficient (and if you find yourself using this a lot, please reach
+out, we'd love to hear your use-case).
+
+This has the exact same rules as the column flavor, except that the return type is a dataframe.
+
+```python
+import pyspark.sql as ps
+
+from hamilton.experimental import h_spark
+
+
+@h_spark.require_columns("bar", "baz")
+def foo(df: ps.DataFrame) -> ps.DataFrame:
+    return df.withColumn("bar", df["bar"] + 1)
+```
+
+Note that this is isomorphic to the column flavor, except that you (not the framework) are responsible for calling `withColumn`.
+
+We have implemented the hamilton hello_world example in [run.py](run.py) and the [map_transforms.py](map_transforms.py)/[dataflow.py](dataflow.py) files
+so you can compare. You can run `run.py`:
+
+`python run.py`
+
+and check out the interactive example in the `notebook.ipynb` file.
+
+We have also implemented three of the [TPC-H](https://www.tpc.org/tpch/) query functions to demonstrate a more real-world set of queries:
+1. [query_1](../tpc-h/query_1.py)
+2. [query_8](../tpc-h/query_8.py)
+3. [query_12](../tpc-h/query_12.py)
+
+See the [README](../tpc-h/README.md) for more details on how to run these.
+
+
+## Technical Details
+
+The `with_columns` decorator does the following:
+1. Resolves the functions you pass in, with the config passed from the driver
+2. Transforms them each into a node, in topological order.
+   - Retains all specified dependencies
+   - Adds a single dataframe that gets wired through (linearizing the operations)
+   - Transforms each function into a function of that input dataframe and any other external dependencies
+
+Thus the graph continually assigns to a single (immutable) dataframe, tracking the result, and still displays the DAG shape
+that was presented by the code. Column-level lineage is preserved as dependencies and easy to read from the code, while it executes as a
+normal set of spark operations.
+
+## Scaling Alternatives
+
+Pyspark is not the only way to scale up your computation. Hamilton supports `pandas-on-spark` as well. You can use pandas-on-spark with the `KoalaGraphAdapter` -- see [Pandas on Spark](../pandas_on_spark/README.md) for reference.
+Some people prefer vanilla spark, some like pandas-on-spark. We support both. Hamilton also supports executing map-based pandas UDFs in pyspark, in case you want simple parallelism. See [pyspark_udfs](../pyspark_udfs/README.md) for reference.
+
+Hamilton has integrations with other scaling libraries as well -- it all depends on your use-case:
+
+- [dask](../../dask/README.md)
+- [ray](../../ray/README.md)
+- [modin](https://github.com/modin-project/modin) (no example for modin yet but it is just the pandas API with a different import)
+
+## Next Steps
+A few interesting directions:
+1. Improve the visualization, allowing you to differentiate the dependencies that just exist for structure from the central linear dependency that flows through.
+2. 
Add similar capabilities for other dataframe libraries for which series/indices are not first-class citizens (polars, etc...) +3. Add data quality decorations for spark -- this is complex as they often require realization of the DAG, which one typically wishes to delay along with the rest of the computation +4. Add more data loaders to seamlessly load data to spark/from spark +5. Add natural constructs for `collect()`/`cache()` through the DAG diff --git a/examples/spark/pyspark/dataflow.py b/examples/spark/pyspark/dataflow.py new file mode 100644 index 000000000..01efbde6c --- /dev/null +++ b/examples/spark/pyspark/dataflow.py @@ -0,0 +1,124 @@ +from typing import Dict + +import map_transforms +import pandas as pd +import pyspark.sql as ps +from pyspark.sql.functions import col, mean, stddev + +from hamilton.experimental import h_spark +from hamilton.function_modifiers import extract_fields + + +def spark_session() -> ps.SparkSession: + """Pyspark session to load up when starting. + You can also pass it in if you so choose. + + :return: + """ + return ps.SparkSession.builder.master("local[1]").getOrCreate() + + +def base_df(spark_session: ps.SparkSession) -> ps.DataFrame: + """Dummy function showing how to wire through loading data. + Note you can use @load_from (although our spark data loaders are limited now). + + :return: A dataframe with spend and signups columns. + """ + pd_df = pd.DataFrame( + { + "spend": [ + 10, + 10, + 20, + 40, + 40, + 50, + 60, + 70, + 90, + 100, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + ], + "signups": [ + 1, + 10, + 50, + 100, + 200, + 400, + 600, + 800, + 1000, + 1200, + 1400, + 1600, + 1800, + 2000, + 2200, + 2400, + 2600, + 2800, + 3000, + 3200, + ], + } + ) + return spark_session.createDataFrame(pd_df) + + +@extract_fields( + { + "spend_mean": float, + "spend_std_dev": float, + } +) +def spend_statistics(base_df: ps.DataFrame) -> Dict[str, float]: + """Computes the mean and standard deviation of the spend column. + Note that this is a blocking (collect) operation, + but it doesn't have to be if you use an aggregation. In that case + you'd just add the column to the dataframe and refer to it downstream, + by expanding `columns_to_pass` in `with_mapped_data`. + + :param base_df: Base dataframe with spend and signups columns. + :return: A dictionary with the mean and standard deviation of the spend column. + """ + df_stats = base_df.select( + mean(col("spend")).alias("mean"), stddev(col("spend")).alias("std") + ).collect() + + return { + "spend_mean": df_stats[0]["mean"], + "spend_std_dev": df_stats[0]["std"], + } + + +@h_spark.with_columns( + map_transforms, + columns_to_pass=["spend", "signups"], +) +def with_mapped_data(base_df: ps.DataFrame) -> ps.DataFrame: + """Applies all the transforms in map_transforms + + :param base_df: + :return: + """ + return base_df + + +def final_result(with_mapped_data: ps.DataFrame) -> pd.DataFrame: + """Computes the final result. You could always just output the pyspark + dataframe, but we'll collect it and make it a pandas dataframe. + + :param base_df: Base dataframe with spend and signups columns. + :return: A dataframe with the final result. 
+ """ + return with_mapped_data.toPandas() diff --git a/examples/spark/pyspark/grouped_transformations.png b/examples/spark/pyspark/grouped_transformations.png new file mode 100644 index 000000000..f2ade2850 Binary files /dev/null and b/examples/spark/pyspark/grouped_transformations.png differ diff --git a/examples/spark/pyspark/illustration.png b/examples/spark/pyspark/illustration.png new file mode 100644 index 000000000..724fe7fa3 Binary files /dev/null and b/examples/spark/pyspark/illustration.png differ diff --git a/examples/spark/pyspark/map_transforms.py b/examples/spark/pyspark/map_transforms.py new file mode 100644 index 000000000..876268414 --- /dev/null +++ b/examples/spark/pyspark/map_transforms.py @@ -0,0 +1,20 @@ +import pandas as pd + +from hamilton.htypes import column + + +def spend_per_signup(spend: pd.Series, signups: pd.Series) -> column[pd.Series, float]: + """The cost per signup in relation to spend.""" + return spend / signups + + +def spend_zero_mean(spend: pd.Series, spend_mean: float) -> column[pd.Series, float]: + """Shows function that takes a scalar. In this case to zero mean spend.""" + return spend - spend_mean + + +def spend_zero_mean_unit_variance( + spend_zero_mean: pd.Series, spend_std_dev: float +) -> column[pd.Series, float]: + """Function showing one way to make spend have zero mean and unit variance.""" + return spend_zero_mean / spend_std_dev diff --git a/examples/spark/pyspark/notebook.ipynb b/examples/spark/pyspark/notebook.ipynb new file mode 100644 index 000000000..e7f704ed1 --- /dev/null +++ b/examples/spark/pyspark/notebook.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 13, + "id": "4c8c7cb7", + "metadata": {}, + "outputs": [], + "source": [ + "import pyspark.sql as ps\n", + "import pandas as pd\n", + "from pyspark.sql.functions import col, mean, stddev" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a85bb3cf", + "metadata": {}, + "outputs": [], + "source": [ + "spark_session = ps.SparkSession.builder.master(\"local[1]\").getOrCreate()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ac02f09c", + "metadata": {}, + "outputs": [], + "source": [ + "pd_df = pd.DataFrame(\n", + " {\n", + " \"spend\": [\n", + " 10,\n", + " 10,\n", + " 20,\n", + " 40,\n", + " 40,\n", + " 50,\n", + " 60,\n", + " 70,\n", + " 90,\n", + " 100,\n", + " 70,\n", + " 80,\n", + " 90,\n", + " 100,\n", + " 110,\n", + " 120,\n", + " 130,\n", + " 140,\n", + " 150,\n", + " 160,\n", + " ],\n", + " \"signups\": [\n", + " 1,\n", + " 10,\n", + " 50,\n", + " 100,\n", + " 200,\n", + " 400,\n", + " 600,\n", + " 800,\n", + " 1000,\n", + " 1200,\n", + " 1400,\n", + " 1600,\n", + " 1800,\n", + " 2000,\n", + " 2200,\n", + " 2400,\n", + " 2600,\n", + " 2800,\n", + " 3000,\n", + " 3200,\n", + " ],\n", + " }\n", + " )\n", + "ps_df = spark_session.createDataFrame(pd_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "71fd52ed", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataFrame[foo: double]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ps_df.select(mean(col(\"spend\")).alias(\"foo\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "8986d1ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataFrame[spend: bigint, signups: bigint, foo: bigint]" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ 
+ "ps_df.withColumn(\"foo\", ps_df['signups']*ps_df['spend'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7489e4dd", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/spark/pyspark/out.png b/examples/spark/pyspark/out.png new file mode 100644 index 000000000..e86907761 --- /dev/null +++ b/examples/spark/pyspark/out.png @@ -0,0 +1,25 @@ +// Dependency Graph +digraph { + base_df [label=base_df] + spend_statistics [label=spend_statistics] + spark_session [label=spark_session] + spend_mean [label=spend_mean] + spend_std_dev [label=spend_std_dev] + "with_mapped_data.spend_per_signup" [label="with_mapped_data.spend_per_signup"] + with_mapped_data [label=with_mapped_data] + "with_mapped_data.spend_zero_mean" [label="with_mapped_data.spend_zero_mean"] + final_result [label=final_result shape=rectangle] + "with_mapped_data.spend_zero_mean_unit_variance" [label="with_mapped_data.spend_zero_mean_unit_variance"] + spark_session -> base_df + base_df -> spend_statistics + spend_statistics -> spend_mean + spend_statistics -> spend_std_dev + base_df -> "with_mapped_data.spend_per_signup" + "with_mapped_data.spend_zero_mean_unit_variance" -> with_mapped_data + spend_mean -> "with_mapped_data.spend_zero_mean" + "with_mapped_data.spend_per_signup" -> "with_mapped_data.spend_zero_mean" + base_df -> "with_mapped_data.spend_zero_mean" + with_mapped_data -> final_result + "with_mapped_data.spend_zero_mean" -> "with_mapped_data.spend_zero_mean_unit_variance" + spend_std_dev -> "with_mapped_data.spend_zero_mean_unit_variance" +} diff --git a/examples/spark/pyspark/run.py b/examples/spark/pyspark/run.py new file mode 100644 index 000000000..227f8a5a0 --- /dev/null +++ b/examples/spark/pyspark/run.py @@ -0,0 +1,15 @@ +import dataflow +import map_transforms + +from hamilton import driver + + +def main(): + dr = driver.Builder().with_modules(dataflow, map_transforms).build() + dr.visualize_execution(["final_result"], "./out.png", {"format": "png"}) + final_result = dr.execute(["final_result"]) + print(final_result) + + +if __name__ == "__main__": + main() diff --git a/examples/spark/pyspark_udfs/my_spark_udf.dot.png b/examples/spark/pyspark_udfs/my_spark_udf.dot.png index 35f92814c..985947a8d 100644 Binary files a/examples/spark/pyspark_udfs/my_spark_udf.dot.png and b/examples/spark/pyspark_udfs/my_spark_udf.dot.png differ diff --git a/examples/spark/tpc-h/README.md b/examples/spark/tpc-h/README.md new file mode 100644 index 000000000..18e348f2b --- /dev/null +++ b/examples/spark/tpc-h/README.md @@ -0,0 +1,12 @@ +# TPC-H + +We've represented a few TPC-h queries using pyspark + hamilton. + +While we have not optimized these for benchmarking, they provide a good set of examples for how to express pyspark logic/break +it into hamilton functions. + +## Running + +To run, you have `run.py` -- this enables you to run a few of the queries. That said, you'll have to generate the data on your own, which is a bit tricky. + +Download dbgen here, and follow the instructions: https://www.tpc.org/tpch/. 
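+Once the `.tbl` files are generated into a directory, a query can then be run with, e.g., `python run.py --data-dir /path/to/tpch-data --which query_1 --visualize` (the data path here is just a placeholder; see `run.py` in this directory for the exact flags).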
You can also reach out to us and we'll help you get set up. diff --git a/examples/spark/tpc-h/csv_data_loaders.py b/examples/spark/tpc-h/csv_data_loaders.py new file mode 100644 index 000000000..3aca0f042 --- /dev/null +++ b/examples/spark/tpc-h/csv_data_loaders.py @@ -0,0 +1,62 @@ +import os + +import pyspark.sql as ps + +from hamilton.function_modifiers import load_from, parameterize, source, value + + +@parameterize( + customer_path={"suffix": value("customer.tbl")}, + lineitem_path={"suffix": value("lineitem.tbl")}, + nation_path={"suffix": value("nation.tbl")}, + orders_path={"suffix": value("orders.tbl")}, + part_path={"suffix": value("part.tbl")}, + partsupp_path={"suffix": value("partsupp.tbl")}, + region_path={"suffix": value("region.tbl")}, + supplier_path={"suffix": value("supplier.tbl")}, +) +def paths(suffix: str, data_dir: str) -> str: + return os.path.join(data_dir, suffix) + + +@load_from.csv(path=source("customer_path"), sep=value("|"), spark=source("spark")) +def customer(df: ps.DataFrame) -> ps.DataFrame: + return df + + +# TODO -- parameterize these, but this is fine for now + + +@load_from.csv(path=source("lineitem_path"), sep=value("|"), spark=source("spark")) +def lineitem(df: ps.DataFrame) -> ps.DataFrame: + return df + + +@load_from.csv(path=source("nation_path"), sep=value("|"), spark=source("spark")) +def nation(df: ps.DataFrame) -> ps.DataFrame: + return df + + +@load_from.csv(path=source("orders_path"), sep=value("|"), spark=source("spark")) +def orders(df: ps.DataFrame) -> ps.DataFrame: + return df + + +@load_from.csv(path=source("part_path"), sep=value("|"), spark=source("spark")) +def part(df: ps.DataFrame) -> ps.DataFrame: + return df + + +@load_from.csv(path=source("partsupp_path"), sep=value("|"), spark=source("spark")) +def partsupp(df: ps.DataFrame) -> ps.DataFrame: + return df + + +@load_from.csv(path=source("region_path"), sep=value("|"), spark=source("spark")) +def region(df: ps.DataFrame) -> ps.DataFrame: + return df + + +@load_from.csv(path=source("supplier_path"), sep=value("|"), spark=source("spark")) +def supplier(df: ps.DataFrame) -> ps.DataFrame: + return df diff --git a/examples/spark/tpc-h/dag-query_1.pdf b/examples/spark/tpc-h/dag-query_1.pdf new file mode 100644 index 000000000..93c6d0c55 Binary files /dev/null and b/examples/spark/tpc-h/dag-query_1.pdf differ diff --git a/examples/spark/tpc-h/dag-query_12.pdf b/examples/spark/tpc-h/dag-query_12.pdf new file mode 100644 index 000000000..ecfa37628 Binary files /dev/null and b/examples/spark/tpc-h/dag-query_12.pdf differ diff --git a/examples/spark/tpc-h/dag-query_8.pdf b/examples/spark/tpc-h/dag-query_8.pdf new file mode 100644 index 000000000..221b283a6 Binary files /dev/null and b/examples/spark/tpc-h/dag-query_8.pdf differ diff --git a/examples/spark/tpc-h/query_1.py b/examples/spark/tpc-h/query_1.py new file mode 100644 index 000000000..25c178df0 --- /dev/null +++ b/examples/spark/tpc-h/query_1.py @@ -0,0 +1,56 @@ +import datetime + +import pandas as pd +import pyspark.sql as ps +from pyspark.sql import functions as F + +from hamilton import htypes +from hamilton.experimental import h_spark + + +# See https://github.com/dragansah/tpch-dbgen/blob/master/tpch-queries/1.sql +def start_date() -> str: + return (datetime.date(1998, 12, 1) - datetime.timedelta(days=90)).format("YYYY-MM-DD") + + +def lineitem_filtered(lineitem: ps.DataFrame, start_date: str) -> ps.DataFrame: + return lineitem.filter((lineitem.l_shipdate <= start_date)) + + +def disc_price( + l_extendedprice: pd.Series, 
l_discount: pd.Series +) -> htypes.column[pd.Series, float]: + return l_extendedprice * (1 - l_discount) + + +def charge( + l_extendedprice: pd.Series, l_discount: pd.Series, l_tax: pd.Series +) -> htypes.column[pd.Series, float]: + # This could easily depend on the prior one... + return l_extendedprice * (1 - l_discount) * (1 + l_tax) + + +@h_spark.with_columns( + disc_price, + charge, + columns_to_pass=["l_extendedprice", "l_discount", "l_tax"], +) +def lineitem_price_charge(lineitem: ps.DataFrame) -> ps.DataFrame: + return lineitem + + +def final_data(lineitem_price_charge: ps.DataFrame) -> ps.DataFrame: + return ( + lineitem_price_charge.groupBy(["l_returnflag", "l_linestatus"]) + .agg( + F.sum("l_quantity").alias("sum_l_quantity"), + F.avg("l_quantity").alias("avg_l_quantity"), + F.sum("l_extendedprice").alias("sum_l_extendedprice"), + F.avg("l_extendedprice").alias("avg_l_extendedprice"), + F.sum("disc_price").alias("sum_disc_price"), + F.sum("charge").alias("sum_charge"), + F.avg("l_discount").alias("avg_l_discount"), + F.count("*").alias("count"), + ) + .orderBy(["l_returnflag", "l_linestatus"]) + ).toPandas() diff --git a/examples/spark/tpc-h/query_12.py b/examples/spark/tpc-h/query_12.py new file mode 100644 index 000000000..903566e0f --- /dev/null +++ b/examples/spark/tpc-h/query_12.py @@ -0,0 +1,56 @@ +import pandas as pd +import pyspark.sql as ps +from pyspark.sql import functions as F + +from hamilton import htypes +from hamilton.experimental import h_spark + +# see # See # See https://github.com/dragansah/tpch-dbgen/blob/master/tpch-queries/12.sql + + +def lineitems_joined_with_orders(lineitem: ps.DataFrame, orders: ps.DataFrame) -> ps.DataFrame: + return lineitem.join(orders, lineitem.l_orderkey == orders.o_orderkey) + + +def start_date() -> str: + return "1995-01-01" + + +def end_date() -> str: + return "1996-12-31" + + +def filtered_data( + lineitems_joined_with_orders: ps.DataFrame, start_date: str, end_date: str +) -> ps.DataFrame: + return lineitems_joined_with_orders.filter( + (lineitems_joined_with_orders.l_shipmode.isin("MAIL", "SHIP")) + & (lineitems_joined_with_orders.l_commitdate < lineitems_joined_with_orders.l_receiptdate) + & (lineitems_joined_with_orders.l_shipdate < lineitems_joined_with_orders.l_commitdate) + & (lineitems_joined_with_orders.l_receiptdate >= start_date) + & (lineitems_joined_with_orders.l_receiptdate < end_date) + ) + + +def high_priority(o_orderpriority: pd.Series) -> htypes.column[pd.Series, int]: + return (o_orderpriority == "1-URGENT") | (o_orderpriority == "2-HIGH") + + +def low_priority(o_orderpriority: pd.Series) -> htypes.column[pd.Series, int]: + return (o_orderpriority != "1-URGENT") & (o_orderpriority != "2-HIGH") + + +@h_spark.with_columns(high_priority, low_priority, columns_to_pass=["o_orderpriority"]) +def with_priorities(filtered_data: ps.DataFrame) -> ps.DataFrame: + return filtered_data + + +def shipmode_aggregated(with_priorities: ps.DataFrame) -> ps.DataFrame: + return with_priorities.groupBy("l_shipmode").agg( + F.sum("high_priority").alias("sum_high"), + F.sum("low_priority").alias("sum_low"), + ) + + +def final_data(shipmode_aggregated: ps.DataFrame) -> pd.DataFrame: + return shipmode_aggregated.toPandas() diff --git a/examples/spark/tpc-h/query_8.py b/examples/spark/tpc-h/query_8.py new file mode 100644 index 000000000..4daaa13f9 --- /dev/null +++ b/examples/spark/tpc-h/query_8.py @@ -0,0 +1,91 @@ +import pandas as pd +import pyspark.sql as ps +import pyspark.sql.functions as F + +# See # See 
https://github.com/dragansah/tpch-dbgen/blob/master/tpch-queries/8.sql +from hamilton import htypes +from hamilton.experimental import h_spark + + +def start_date() -> str: + return "1995-01-01" + + +def end_date() -> str: + return "1996-12-31" + + +def america(region: ps.DataFrame) -> ps.DataFrame: + return region.filter(F.col("r_name") == "AMERICA") + + +def american_nations(nation: ps.DataFrame, america: ps.DataFrame) -> ps.DataFrame: + return nation.join(america, nation.n_regionkey == america.r_regionkey).select(["n_nationkey"]) + + +def american_customers(customer: ps.DataFrame, american_nations: ps.DataFrame) -> ps.DataFrame: + return customer.join(american_nations, customer.c_nationkey == american_nations.n_nationkey) + + +def american_orders(orders: ps.DataFrame, american_customers: ps.DataFrame) -> ps.DataFrame: + return orders.join(american_customers, orders.o_custkey == american_customers.c_custkey) + + +def order_data_augmented( + american_orders: ps.DataFrame, + lineitem: ps.DataFrame, + supplier: ps.DataFrame, + nation: ps.DataFrame, + part: ps.DataFrame, +) -> ps.DataFrame: + d = lineitem.join(part, lineitem.l_partkey == part.p_partkey).drop("n_nation", "n_nationkey") + d = d.join(american_orders.drop("n_nationkey"), d.l_orderkey == american_orders.o_orderkey) + d = d.join(supplier, d.l_suppkey == supplier.s_suppkey) + d = d.join(nation, d.s_nationkey == nation.n_nationkey) + return d + + +def order_data_filtered( + order_data_augmented: ps.DataFrame, + start_date: str, + end_date: str, + p_type: str = "ECONOMY ANODIZED STEEL", +) -> ps.DataFrame: + return order_data_augmented.filter( + (F.col("o_orderdate") >= F.to_date(F.lit(start_date))) + & (F.col("o_orderdate") <= F.to_date(F.lit(end_date))) + & (F.col("p_type") == p_type) + ) + + +def o_year(o_orderdate: pd.Series) -> htypes.column[pd.Series, int]: + return pd.to_datetime(o_orderdate).dt.year + + +def volume(l_extendedprice: pd.Series, l_discount: pd.Series) -> htypes.column[pd.Series, float]: + return l_extendedprice * (1 - l_discount) + + +def brazil_volume(n_name: pd.Series, volume: pd.Series) -> htypes.column[pd.Series, float]: + return volume.where(n_name == "BRAZIL", 0) + + +@h_spark.with_columns( + o_year, + volume, + brazil_volume, + columns_to_pass=["o_orderdate", "l_extendedprice", "l_discount", "n_name", "volume"], + select=["o_year", "volume", "brazil_volume"], +) +def processed(order_data_filtered: ps.DataFrame) -> ps.DataFrame: + return order_data_filtered + + +def brazil_volume_by_year(processed: ps.DataFrame) -> ps.DataFrame: + return processed.groupBy("o_year").agg( + F.sum("volume").alias("sum_volume"), F.sum("brazil_volume").alias("sum_brazil_volume") + ) + + +def final_data(brazil_volume_by_year: ps.DataFrame) -> pd.DataFrame: + return brazil_volume_by_year.toPandas() diff --git a/examples/spark/tpc-h/run.py b/examples/spark/tpc-h/run.py new file mode 100644 index 000000000..2649cff8b --- /dev/null +++ b/examples/spark/tpc-h/run.py @@ -0,0 +1,52 @@ +import click + +exec("from hamilton.experimental import h_spark") +import csv_data_loaders +import pyspark +import query_1 +import query_8 +import query_12 + +from hamilton import base, driver + +QUERIES = {"query_1": query_1, "query_8": query_8, "query_12": query_12} + + +def run_query(query: str, data_dir: str, visualize: bool = True): + """Runs the given query""" + + dr = ( + driver.Builder() + .with_modules(QUERIES[query], csv_data_loaders) + .with_adapter(base.DefaultAdapter()) + .build() + ) + spark = pyspark.sql.SparkSession.builder.getOrCreate() 
+ if visualize: + dr.visualize_execution( + ["final_data"], f"./dag-{query}", {}, inputs={"data_dir": data_dir, "spark": spark} + ) + df = dr.execute(["final_data"], inputs={"data_dir": data_dir, "spark": spark})["final_data"] + print(df) + + +@click.command() +@click.option("--data-dir", type=str, help="Base directory for data", required=True) +@click.option( + "--which", type=click.Choice(list(QUERIES.keys())), help="Which query to run", required=True +) +@click.option( + "--visualize", + type=bool, + help="Whether to visualize the execution", + is_flag=True, +) +def run_tpch_query(data_dir: str, which: str, visualize: bool): + """Placeholder function for running TPCH query""" + # Place logic here for running the TPCH query with the given 'which' value + click.echo(f"Running TPCH query: {which}") + run_query(which, data_dir, visualize) + + +if __name__ == "__main__": + run_tpch_query() diff --git a/graph_adapter_tests/h_spark/test_h_spark.py b/graph_adapter_tests/h_spark/test_h_spark.py index feecf4c1a..9ec4a0485 100644 --- a/graph_adapter_tests/h_spark/test_h_spark.py +++ b/graph_adapter_tests/h_spark/test_h_spark.py @@ -5,18 +5,31 @@ import pyspark.pandas as ps import pytest from pyspark import Row -from pyspark.sql import SparkSession, types +from pyspark.sql import Column, DataFrame, SparkSession, types from pyspark.sql.functions import column from hamilton import base, driver, htypes, node from hamilton.experimental import h_spark -from .resources import example_module, pyspark_udfs, smoke_screen_module +from .resources import example_module, smoke_screen_module +from .resources.spark import ( + basic_spark_dag, + pyspark_udfs, + spark_dag_external_dependencies, + spark_dag_mixed_pyspark_pandas_udfs, + spark_dag_multiple_with_columns, + spark_dag_pyspark_udfs, +) @pytest.fixture(scope="module") def spark_session(): - spark = SparkSession.builder.getOrCreate() + spark = ( + SparkSession.builder.master("local") + .appName("spark session") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + ) yield spark spark.stop() @@ -89,46 +102,61 @@ def test_smoke_screen_module(spark_session): assert df["series_with_start_date_end_date"].iloc[0] == "date_20200101_date_20220801" -spark = SparkSession.builder.master("local[1]").getOrCreate() - -pandas_df = pd.DataFrame({"spend": [10, 10, 20, 40, 40, 50], "signups": [1, 10, 50, 100, 200, 400]}) -spark_df = spark.createDataFrame(pandas_df) - - @pytest.mark.parametrize( - "input, expected", + "input_and_expected_fn", [ - ({}, (None, {})), - ({"a": 1}, (None, {"a": 1})), - ({"a": spark_df}, (spark_df, {})), - ({"a": spark_df, "b": 1}, (spark_df, {"b": 1})), + (lambda df: ({}, (None, {}))), + (lambda df: ({"a": 1}, (None, {"a": 1}))), + (lambda df: ({"a": df}, (df, {}))), + (lambda df: ({"a": df, "b": 1}, (df, {"b": 1}))), ], ids=["no_kwargs", "one_plain_kwarg", "one_df_kwarg", "one_df_kwarg_and_one_plain_kwarg"], ) -def test__inspect_kwargs(input, expected): +def test__inspect_kwargs(input_and_expected_fn, spark_session): """A unit test for inspect_kwargs.""" - assert h_spark._inspect_kwargs(input) == expected + pandas_df = pd.DataFrame( + {"spend": [10, 10, 20, 40, 40, 50], "signups": [1, 10, 50, 100, 200, 400]} + ) + df = spark_session.createDataFrame(pandas_df) + input_, expected = input_and_expected_fn(df) + assert h_spark._inspect_kwargs(input_) == expected -def test__get_pandas_annotations(): +def test__get_pandas_annotations_no_pandas(): """Unit test for _get_pandas_annotations().""" def no_pandas(a: int, b: float) -> float: return 
a * b + assert h_spark._get_pandas_annotations(node.Node.from_fn(no_pandas), {}) == { + "a": False, + "b": False, + } + + +def test__get_pandas_annotations_with_pandas(): def with_pandas(a: pd.Series) -> pd.Series: return a * 2 - def with_pandas_and_other_default(a: pd.Series, b: int = 2) -> pd.Series: + assert h_spark._get_pandas_annotations(node.Node.from_fn(with_pandas), {}) == {"a": True} + + +def test__get_pandas_annotations_with_pandas_and_other_default(): + def with_pandas_and_other_default(a: pd.Series, b: int) -> pd.Series: return a * b + assert h_spark._get_pandas_annotations( + node.Node.from_fn(with_pandas_and_other_default), {"b": 2} + ) == {"a": True} + + +def test__get_pandas_annotations_with_pandas_and_other_default_and_one_more(): def with_pandas_and_other_default_with_one_more(a: pd.Series, c: int, b: int = 2) -> pd.Series: - return a * b + return a * b * c - assert h_spark._get_pandas_annotations(no_pandas) == {"a": False, "b": False} - assert h_spark._get_pandas_annotations(with_pandas) == {"a": True} - assert h_spark._get_pandas_annotations(with_pandas_and_other_default) == {"a": True} - assert h_spark._get_pandas_annotations(with_pandas_and_other_default_with_one_more) == { + assert h_spark._get_pandas_annotations( + node.Node.from_fn(with_pandas_and_other_default_with_one_more), {"b": 2} + ) == { "a": True, "c": False, } @@ -136,21 +164,51 @@ def with_pandas_and_other_default_with_one_more(a: pd.Series, c: int, b: int = 2 def test__bind_parameters_to_callable(): """Unit test for _bind_parameters_to_callable().""" + actual_kwargs = {"a": 1, "b": 2} + df_columns = {"b"} + node_input_types = { + "a": (int, node.DependencyType.REQUIRED), + "b": (int, node.DependencyType.REQUIRED), + } + df_params, params_to_bind = h_spark._determine_parameters_to_bind( + actual_kwargs, df_columns, node_input_types, "test" + ) + assert isinstance(df_params["b"], Column) + assert params_to_bind == {"a": 1} + assert str(df_params["b"]) == str(column("b")) # hacky, but compare string representation. - def base_func(a: int, b: int) -> int: - return a + b - actual_kwargs = {"a": 1, "b": 2} +def test__bind_parameters_to_callable_with_defaults_provided(): + """Unit test for _bind_parameters_to_callable().""" + actual_kwargs = {"a": 1, "b": 2, "c": 2} df_columns = {"b"} - node_input_types = {"a": (int,), "b": (int,)} - mod_func, df_params = h_spark._bind_parameters_to_callable( - actual_kwargs, df_columns, base_func, node_input_types, "test" + node_input_types = { + "a": (int, node.DependencyType.REQUIRED), + "b": (int, node.DependencyType.REQUIRED), + "c": (int, node.DependencyType.OPTIONAL), + } + df_params, params_to_bind = h_spark._determine_parameters_to_bind( + actual_kwargs, df_columns, node_input_types, "test" ) - import inspect + assert isinstance(df_params["b"], Column) + assert params_to_bind == {"a": 1, "c": 2} + assert str(df_params["b"]) == str(column("b")) # hacky, but compare string representation. 
- sig = inspect.signature(mod_func) - assert sig.parameters["a"].default == 1 - assert sig.parameters["b"].default == inspect.Parameter.empty + +def test__bind_parameters_to_callable_with_defaults_not_provided(): + """Unit test for _bind_parameters_to_callable().""" + actual_kwargs = {"a": 1, "b": 2, "c": 2} + df_columns = {"b"} + node_input_types = { + "a": (int, node.DependencyType.REQUIRED), + "b": (int, node.DependencyType.REQUIRED), + "c": (int, node.DependencyType.OPTIONAL), + } + df_params, params_to_bind = h_spark._determine_parameters_to_bind( + actual_kwargs, df_columns, node_input_types, "test" + ) + assert isinstance(df_params["b"], Column) + assert params_to_bind == {"a": 1, "c": 2} assert str(df_params["b"]) == str(column("b")) # hacky, but compare string representation. @@ -161,18 +219,8 @@ def base_func(a: int, b: int) -> int: return a + b base_spark_df = spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) - node_ = node.Node( - "test", - int, - "", - base_func, - input_types={ - "a": (int, node.DependencyType.REQUIRED), - "b": (int, node.DependencyType.REQUIRED), - }, - ) - - new_df = h_spark._lambda_udf(base_spark_df, node_, base_func, {}) + node_ = node.Node.from_fn(base_func) + new_df = h_spark._lambda_udf(base_spark_df, node_, {}) assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)] @@ -183,18 +231,9 @@ def base_func(a: pd.Series, b: pd.Series) -> htypes.column[pd.Series, int]: return a + b base_spark_df = spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) - node_ = node.Node( - "test", - htypes.column[pd.Series, int], - "", - base_func, - input_types={ - "a": (int, node.DependencyType.REQUIRED), - "b": (int, node.DependencyType.REQUIRED), - }, - ) - - new_df = h_spark._lambda_udf(base_spark_df, node_, base_func, {}) + node_ = node.Node.from_fn(base_func) + + new_df = h_spark._lambda_udf(base_spark_df, node_, {}) assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)] @@ -205,22 +244,13 @@ def base_func(a: pd.Series, b: int) -> htypes.column[pd.Series, int]: return a + b base_spark_df = spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) - node_ = node.Node( - "test", - htypes.column[pd.Series, int], - "", - base_func, - input_types={ - "a": (int, node.DependencyType.REQUIRED), - "b": (int, node.DependencyType.REQUIRED), - }, - ) + node_ = node.Node.from_fn(base_func) with pytest.raises(ValueError): - h_spark._lambda_udf(base_spark_df, node_, base_func, {"a": 1}) + h_spark._lambda_udf(base_spark_df, node_, {"a": 1}) -def test_smoke_screen_udf_graph_adatper(spark_session): +def test_smoke_screen_udf_graph_adapter(spark_session): """Tests that we can run the PySparkUDFGraphAdapter on a simple graph. THe graph has a pandas UDF, a plain UDF that depends on the output of the pandas UDF, and @@ -273,13 +303,8 @@ def test_python_to_spark_type_invalid(invalid_python_type): (bytes, types.BinaryType()), ], ) -def test_get_spark_type_basic_types( - dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type -): - assert ( - h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type) - == expected_spark_type - ) +def test_get_spark_type_basic_types(return_type, expected_spark_type): + assert h_spark.get_spark_type(return_type) == expected_spark_type # 2. 
Lists of basic Python types @@ -294,14 +319,9 @@ def test_get_spark_type_basic_types( (bytes, types.ArrayType(types.BinaryType())), ], ) -def test_get_spark_type_list_types( - dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type -): +def test_get_spark_type_list_types(return_type, expected_spark_type): return_type = list[return_type] # type: ignore - assert ( - h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type) - == expected_spark_type - ) + assert h_spark.get_spark_type(return_type) == expected_spark_type # 3. Numpy types (assuming you have a numpy_to_spark_type function that handles these) @@ -313,24 +333,19 @@ def test_get_spark_type_list_types( (np.bool_, types.BooleanType()), ], ) -def test_get_spark_type_numpy_types( - dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type -): - assert ( - h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type) - == expected_spark_type - ) +def test_get_spark_type_numpy_types(return_type, expected_spark_type): + assert h_spark.get_spark_type(return_type) == expected_spark_type # 4. Unsupported types @pytest.mark.parametrize( "unsupported_return_type", [dict, set, tuple] # Add other unsupported types as needed ) -def test_get_spark_type_unsupported(dummy_kwargs, dummy_df, dummy_udf, unsupported_return_type): +def test_get_spark_type_unsupported(unsupported_return_type): with pytest.raises( ValueError, match=f"Currently unsupported return type {unsupported_return_type}." ): - h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, unsupported_return_type) + h_spark.get_spark_type(unsupported_return_type) # Dummy values for the tests @@ -340,8 +355,8 @@ def dummy_kwargs(): @pytest.fixture -def dummy_df(): - return spark.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) +def dummy_df(spark_session): + return spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) @pytest.fixture @@ -350,3 +365,410 @@ def dummyfunc(x: int) -> int: return x return dummyfunc + + +def test_base_spark_executor_end_to_end(spark_session): + # TODO -- make this simpler to call, and not require all these constructs + dr = ( + driver.Builder() + .with_modules(basic_spark_dag) + .with_adapter(base.SimplePythonGraphAdapter(base.DictResult())) + .build() + ) + # dr.visualize_execution( + # ["processed_df_as_pandas"], "./out", {}, inputs={"spark_session": spark_session} + # ) + df = dr.execute(["processed_df_as_pandas"], inputs={"spark_session": spark_session})[ + "processed_df_as_pandas" + ] + expected_data = { + "a_times_key": [2, 10, 24, 44, 70], + "b_times_key": [5, 16, 33, 56, 85], + "a_plus_b_plus_c": [10.5, 20.0, 29.5, 39.0, 48.5], + } + expected_df = pd.DataFrame(expected_data) + pd.testing.assert_frame_equal(df, expected_df, check_names=False, check_dtype=False) + + +def test_base_spark_executor_end_to_end_with_mode_select(spark_session): + # TODO -- make this simpler to call, and not require all these constructs + dr = ( + driver.Builder() + .with_modules(basic_spark_dag) + .with_adapter(base.SimplePythonGraphAdapter(base.DictResult())) + .with_config({"mode": "select"}) + .build() + ) + # dr.visualize_execution( + # ["processed_df_as_pandas"], "./out", {}, inputs={"spark_session": spark_session} + # ) + df = dr.execute(["processed_df_as_pandas"], inputs={"spark_session": spark_session})[ + "processed_df_as_pandas" + ] + expected_data = { + "a_times_key": [2, 10, 24, 44, 70], + "a_plus_b_plus_c": [10.5, 20.0, 29.5, 39.0, 48.5], + } + expected_df = pd.DataFrame(expected_data) 
+ pd.testing.assert_frame_equal(df, expected_df, check_names=False, check_dtype=False) + + +def test_base_spark_executor_end_to_end_external_dependencies(spark_session): + # TODO -- make this simpler to call, and not require all these constructs + dr = ( + driver.Builder() + .with_modules(spark_dag_external_dependencies) + .with_adapter(base.SimplePythonGraphAdapter(base.DictResult())) + .build() + ) + dfs = dr.execute( + ["processed_df_as_pandas"], + inputs={"spark_session": spark_session}, + ) + + expected_df = pd.DataFrame({"a": [2, 3, 4, 5], "b": [4, 6, 8, 10]}) + processed_df_as_pandas = pd.DataFrame(dfs["processed_df_as_pandas"]) + pd.testing.assert_frame_equal( + processed_df_as_pandas, expected_df, check_names=False, check_dtype=False + ) + + +def test_base_spark_executor_end_to_end_multiple_with_columns(spark_session): + dr = ( + driver.Builder() + .with_modules(spark_dag_multiple_with_columns) + .with_adapter(base.SimplePythonGraphAdapter(base.DictResult())) + .build() + ) + df = dr.execute(["final"], inputs={"spark_session": spark_session})["final"].sort_index(axis=1) + + expected_df = pd.DataFrame( + { + "d_raw": [1, 4, 7, 10], + "e_raw": [2, 5, 8, 11], + "f_raw": [5, 10, 15, 20], + "d": [6, 9, 12, 15], + "f": [17.5, 35.0, 52.5, 70.0], + "e": [12.3, 18.299999, 24.299999, 30.299999], + "multiply_d_e_f_key": [1291.5, 11529.0, 45927.0, 127260.0], + "key": [1, 2, 3, 4], + "a_times_key": [2, 10, 24, 44], + "b_times_key": [5, 16, 33, 56], + "a_plus_b_plus_c": [10.5, 20.0, 29.5, 39.0], + } + ).sort_index(axis=1) + pd.testing.assert_frame_equal(df, expected_df, check_names=False, check_dtype=False) + + +def _only_pyspark_dataframe_parameter(foo: DataFrame) -> DataFrame: + ... + + +def _no_pyspark_dataframe_parameter(foo: int) -> int: + ... + + +def _one_pyspark_dataframe_parameter(foo: DataFrame, bar: int) -> DataFrame: + ... + + +def _two_pyspark_dataframe_parameters(foo: DataFrame, bar: int, baz: DataFrame) -> DataFrame: + ... 
+ + +@pytest.mark.parametrize( + "fn,requested_parameter,expected", + [ + (_only_pyspark_dataframe_parameter, "foo", "foo"), + (_one_pyspark_dataframe_parameter, "foo", "foo"), + (_one_pyspark_dataframe_parameter, None, "foo"), + (_two_pyspark_dataframe_parameters, "foo", "foo"), + (_two_pyspark_dataframe_parameters, "baz", "baz"), + ], +) +def test_derive_dataframe_parameter_succeeds(fn, requested_parameter, expected): + assert h_spark.derive_dataframe_parameter_from_fn(fn, requested_parameter) == expected + n = node.Node.from_fn(fn) + assert h_spark.derive_dataframe_parameter_from_node(n, requested_parameter) == expected + + +@pytest.mark.parametrize( + "fn,requested_parameter", + [ + (_no_pyspark_dataframe_parameter, "foo"), + (_no_pyspark_dataframe_parameter, None), + (_one_pyspark_dataframe_parameter, "baz"), + (_two_pyspark_dataframe_parameters, "bar"), + (_two_pyspark_dataframe_parameters, None), + ], +) +def test_derive_dataframe_parameter_fails(fn, requested_parameter): + with pytest.raises(ValueError): + h_spark.derive_dataframe_parameter_from_fn(fn, requested_parameter) + n = node.Node.from_fn(fn) + h_spark.derive_dataframe_parameter_from_node(n, requested_parameter) + + +def test_prune_nodes_no_select(): + nodes = [ + node.Node.from_fn(fn) for fn in [basic_spark_dag.a, basic_spark_dag.b, basic_spark_dag.c] + ] + select = None + assert {n for n in h_spark.prune_nodes(nodes, select)} == set(nodes) + + +def test_prune_nodes_single_select(): + nodes = [ + node.Node.from_fn(fn) for fn in [basic_spark_dag.a, basic_spark_dag.b, basic_spark_dag.c] + ] + select = ["a", "b"] + assert {n for n in h_spark.prune_nodes(nodes, select)} == set(nodes[0:2]) + + +def test_generate_nodes_invalid_select(): + dec = h_spark.with_columns( + basic_spark_dag.a, + basic_spark_dag.b, + basic_spark_dag.c, + select=["d"], # not a node + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], + ) + with pytest.raises(ValueError): + + def df_as_pandas(df: DataFrame) -> pd.DataFrame: + return df.toPandas() + + dec.generate_nodes(df_as_pandas, {}) + + +def test_with_columns_generate_nodes_no_select(): + dec = h_spark.with_columns( + basic_spark_dag.a, + basic_spark_dag.b, + basic_spark_dag.c, + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], + ) + + def df_as_pandas(df: DataFrame) -> pd.DataFrame: + return df.toPandas() + + nodes = dec.generate_nodes(df_as_pandas, {}) + nodes_by_names = {n.name: n for n in nodes} + assert set(nodes_by_names.keys()) == { + "df_as_pandas.a", + "df_as_pandas.b", + "df_as_pandas.c", + "df_as_pandas", + } + + +def test_with_columns_generate_nodes_select(): + dec = h_spark.with_columns( + basic_spark_dag.a, + basic_spark_dag.b, + basic_spark_dag.c, + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], + select=["c"], + ) + + def df_as_pandas(df: DataFrame) -> pd.DataFrame: + return df.toPandas() + + nodes = dec.generate_nodes(df_as_pandas, {}) + nodes_by_names = {n.name: n for n in nodes} + assert set(nodes_by_names.keys()) == {"df_as_pandas.c", "df_as_pandas"} + + +def test_with_columns_generate_nodes_select_mode_select(): + dec = h_spark.with_columns( + basic_spark_dag.a, + basic_spark_dag.b, + basic_spark_dag.c, + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], + select=["c"], + mode="select", + ) + + def df_as_pandas(df: DataFrame) -> pd.DataFrame: + return df.toPandas() + + nodes = dec.generate_nodes(df_as_pandas, {}) + nodes_by_names = {n.name: n for n in nodes} + assert set(nodes_by_names.keys()) == {"df_as_pandas.c", "df_as_pandas", "df_as_pandas._select"} + + +def 
test_with_columns_generate_nodes_specify_namespace(): + dec = h_spark.with_columns( + basic_spark_dag.a, + basic_spark_dag.b, + basic_spark_dag.c, + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], + namespace="foo", + ) + + def df_as_pandas(df: DataFrame) -> pd.DataFrame: + return df.toPandas() + + nodes = dec.generate_nodes(df_as_pandas, {}) + nodes_by_names = {n.name: n for n in nodes} + assert set(nodes_by_names.keys()) == {"foo.a", "foo.b", "foo.c", "df_as_pandas"} + + +def test__format_pandas_udf(): + assert ( + h_spark._format_pandas_udf("foo", ["a", "b"]).strip() + == "def foo(a: pd.Series, b: pd.Series) -> pd.Series:\n" + " return partial_fn(a=a, b=b)" + ) + + +def test__format_standard_udf(): + assert ( + h_spark._format_udf("foo", ["b", "a"]).strip() == "def foo(b, a):\n" + " return partial_fn(b=b, a=a)" + ) + + +def test_sparkify_node(): + def foo( + a_from_upstream: pd.Series, b_from_upstream: pd.Series, c_from_df: pd.Series, d_fixed: int + ) -> htypes.column[pd.Series, int]: + return a_from_upstream + b_from_upstream + c_from_df + d_fixed + + node_ = node.Node.from_fn(foo) + sparkified = h_spark.sparkify_node_with_udf( + node_, + "df_upstream", + "df_base", + None, + {"a_from_upstream", "b_from_upstream"}, + {"c_from_df"}, + ) + # Superset of all the original nodes except the ones from the dataframe + # (as we already have that) both the physical and the logical dependencies + assert set(sparkified.input_types) == { + "a_from_upstream", + "b_from_upstream", + "d_fixed", + "df_base", + "df_upstream", + } + + +def test_pyspark_mixed_pandas_udfs_end_to_end(): + # TODO -- make this simpler to call, and not require all these constructs + dr = ( + driver.Builder() + .with_modules(spark_dag_mixed_pyspark_pandas_udfs) + .with_adapter(base.SimplePythonGraphAdapter(base.DictResult())) + .build() + ) + # dr.visualize_execution( + # ["processed_df_as_pandas_dataframe_with_injected_dataframe"], + # "./out", + # {}, + # inputs={"spark_session": spark_session}, + # ) + results = dr.execute( + ["processed_df_as_pandas_dataframe_with_injected_dataframe", "processed_df_as_pandas"], + inputs={"spark_session": spark_session}, + ) + processed_df_as_pandas = results["processed_df_as_pandas"] + processed_df_as_pandas_dataframe_with_injected_dataframe = results[ + "processed_df_as_pandas_dataframe_with_injected_dataframe" + ] + expected_data = { + "a_times_key": [2, 10, 24, 44, 70], + "b_times_key": [5, 16, 33, 56, 85], + "a_plus_b_plus_c": [10.5, 20.0, 29.5, 39.0, 48.5], + } + expected_df = pd.DataFrame(expected_data) + pd.testing.assert_frame_equal( + processed_df_as_pandas, expected_df, check_names=False, check_dtype=False + ) + pd.testing.assert_frame_equal( + processed_df_as_pandas_dataframe_with_injected_dataframe, + expected_df, + check_names=False, + check_dtype=False, + ) + + +def test_just_pyspark_udfs_end_to_end(): + # TODO -- make this simpler to call, and not require all these constructs + dr = ( + driver.Builder() + .with_modules(spark_dag_pyspark_udfs) + .with_adapter(base.SimplePythonGraphAdapter(base.DictResult())) + .build() + ) + # dr.visualize_execution( + # ["processed_df_as_pandas_with_injected_dataframe", "processed_df_as_pandas"], + # "./out", + # {}, + # ) + results = dr.execute( + ["processed_df_as_pandas_with_injected_dataframe", "processed_df_as_pandas"] + ) + processed_df_as_pandas_with_injected_dataframe = results[ + "processed_df_as_pandas_with_injected_dataframe" + ] + processed_df_as_pandas = results["processed_df_as_pandas"] + expected_data = { + "a_times_key": 
[2, 10, 24, 44, 70], + "b_times_key": [5, 16, 33, 56, 85], + "a_plus_b_plus_c": [10.5, 20.0, 29.5, 39.0, 48.5], + } + expected_df = pd.DataFrame(expected_data) + pd.testing.assert_frame_equal( + processed_df_as_pandas, expected_df, check_names=False, check_dtype=False + ) + pd.testing.assert_frame_equal( + processed_df_as_pandas_with_injected_dataframe, + expected_df, + check_names=False, + check_dtype=False, + ) + + +# is default +def pyspark_fn_1(foo: DataFrame) -> DataFrame: + pass + + +# is default +def pyspark_fn_2(foo: DataFrame, bar: int) -> DataFrame: + pass + + +def not_pyspark_fn(foo: DataFrame, bar: DataFrame) -> DataFrame: + pass + + +@pytest.mark.parametrize( + "fn,expected", [(pyspark_fn_1, True), (pyspark_fn_2, True), (not_pyspark_fn, False)] +) +def test_is_default_pyspark_node(fn, expected): + node_ = node.Node.from_fn(fn) + assert h_spark.require_columns.is_default_pyspark_udf(node_) == expected + + +def fn_test_initial_schema_1(a: int, b: int) -> int: + return a + b + + +def fn_test_initial_schema_2(fn_test_initial_schema_1: int, c: int = 1) -> int: + return fn_test_initial_schema_1 + c + + +def test_create_selector_node(spark_session): + selector_node = h_spark.with_columns.create_selector_node("foo", ["a", "b"], "select") + assert selector_node.name == "select" + pandas_df = pd.DataFrame( + {"a": [10, 10, 20, 40, 40, 50], "b": [1, 10, 50, 100, 200, 400], "c": [1, 2, 3, 4, 5, 6]} + ) + df = spark_session.createDataFrame(pandas_df) + transformed = selector_node(foo=df).toPandas() + pd.testing.assert_frame_equal( + transformed, pandas_df[["a", "b"]], check_names=False, check_dtype=False + ) diff --git a/hamilton/execution/graph_functions.py b/hamilton/execution/graph_functions.py index 89d9d999f..a6837c6b5 100644 --- a/hamilton/execution/graph_functions.py +++ b/hamilton/execution/graph_functions.py @@ -12,6 +12,12 @@ def topologically_sort_nodes(nodes: List[node.Node]) -> List[node.Node]: """Topologically sorts a list of nodes based on their dependencies. + Note that we bypass utilizing the preset dependencies/depended_on_by attributes of the node, + as we may want to use this before these nodes get put in a function graph. + + Thus we compute our own dependency map... + Note that this assumes that the nodes are continuous -- if there is a hidden dependency that + connects them, this has no way of knowing about it. TODO -- use python graphlib when we no longer have to support 3.7/3.8. 
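
The hunk below rewrites `topologically_sort_nodes` to build its own dependency map from each node's `input_types`, ignoring dependencies that fall outside the supplied node set. As a rough, self-contained illustration of that approach (Kahn's algorithm over a plain name-to-dependencies map, not the library code itself):

```python
from collections import deque
from typing import Dict, List


def toposort(deps: Dict[str, List[str]]) -> List[str]:
    """Kahn's algorithm over a name -> dependency-names map.
    Dependencies not present as keys are treated as external and ignored."""
    in_degree = {name: sum(1 for d in ds if d in deps) for name, ds in deps.items()}
    dependents = {name: [] for name in deps}
    for name, ds in deps.items():
        for d in ds:
            if d in dependents:
                dependents[d].append(name)
    queue = deque(name for name, degree in in_degree.items() if degree == 0)
    ordered = []
    while queue:
        current = queue.popleft()
        ordered.append(current)
        for nxt in dependents[current]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                queue.append(nxt)
    return ordered


# "a" depends on the external input "raw", which is ignored for ordering purposes.
assert toposort({"a": ["raw"], "b": ["a"], "c": ["a", "b"]}) == ["a", "b", "c"]
```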
@@ -20,24 +26,36 @@ def topologically_sort_nodes(nodes: List[node.Node]) -> List[node.Node]: :param nodes: Nodes to sort :return: Nodes in sorted order """ - - in_degrees = {node_.name: len(node_.dependencies) for node_ in nodes} + node_name_map = {node_.name: node_ for node_ in nodes} + depended_on_by_map = {} + dependency_map = {} + for node_ in nodes: + dependency_map[node_.name] = [] + for dep in node_.input_types: + # if the dependency is not here, we don't want to count it + # that means it depends on something outside the set of nodes we're sorting + if dep not in node_name_map: + continue + dependency_map[node_.name].append(dep) + if dep not in depended_on_by_map: + depended_on_by_map[dep] = [] + depended_on_by_map[dep].append(node_) + + in_degrees = {node_.name: len(dependency_map.get(node_.name, [])) for node_ in nodes} # TODO -- determine what happens if nodes have dependencies that aren't present - sources = [node_ for node_ in nodes if len(node_.dependencies) == 0] + sources = [node_ for node_ in nodes if in_degrees[node_.name] == 0] queue = [] for source in sources: queue.append(source) - sorted_nodes = [] while len(queue) > 0: node_ = queue.pop(0) sorted_nodes.append(node_) - for next_node in node_.depended_on_by: + for next_node in depended_on_by_map.get(node_.name, []): if next_node.name in in_degrees: in_degrees[next_node.name] -= 1 if in_degrees[next_node.name] == 0: queue.append(next_node) - return sorted_nodes diff --git a/hamilton/experimental/h_spark.py b/hamilton/experimental/h_spark.py index 0883fb692..817906119 100644 --- a/hamilton/experimental/h_spark.py +++ b/hamilton/experimental/h_spark.py @@ -1,17 +1,26 @@ +import abc +import dataclasses import functools import inspect import logging import sys -from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union +from types import CodeType, FunctionType, ModuleType +from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type, Union import numpy as np import pandas as pd import pyspark.pandas as ps -from pyspark.sql import DataFrame, dataframe, types +from pyspark.sql import Column, DataFrame, SparkSession, dataframe, types from pyspark.sql.functions import column, lit, pandas_udf, udf -from hamilton import base, htypes, node -from hamilton.node import DependencyType +from hamilton import base, htypes, node, registry +from hamilton.execution import graph_functions +from hamilton.function_modifiers import base as fm_base +from hamilton.function_modifiers import subdag +from hamilton.function_modifiers.recursive import assign_namespace +from hamilton.htypes import custom_subclass_check +from hamilton.io import utils +from hamilton.io.data_adapters import DataLoader logger = logging.getLogger(__name__) @@ -206,9 +215,7 @@ def python_to_spark_type(python_type: Type[Union[int, float, bool, str, bytes]]) _list = (list[int], list[float], list[bool], list[str], list[bytes]) -def get_spark_type( - actual_kwargs: dict, df: DataFrame, hamilton_udf: Callable, return_type: Any -) -> types.DataType: +def get_spark_type(return_type: Any) -> types.DataType: if return_type in (int, float, bool, str, bytes): return python_to_spark_type(return_type) elif return_type in _list: @@ -216,36 +223,31 @@ def get_spark_type( elif hasattr(return_type, "__module__") and getattr(return_type, "__module__") == "numpy": return numpy_to_spark_type(return_type) else: - logger.debug(f"{inspect.signature(hamilton_udf)}, {actual_kwargs}, {df.columns}") raise ValueError( f"Currently unsupported return type {return_type}. 
" f"Please create an issue or PR to add support for this type." ) -def _get_pandas_annotations(hamilton_udf: Callable) -> Dict[str, bool]: +def _get_pandas_annotations(node_: node.Node, bound_parameters: Dict[str, Any]) -> Dict[str, bool]: """Given a function, return a dictionary of the parameters that are annotated as pandas series. :param hamilton_udf: the function to check. :return: dictionary of parameter names to boolean indicating if they are pandas series. """ - new_signature = inspect.signature(hamilton_udf) - new_sig_parameters = dict(new_signature.parameters) - pandas_annotation = { - name: param.annotation == pd.Series - for name, param in new_sig_parameters.items() - if param.default == inspect.Parameter.empty # bound parameters will have a default value. + return { + name: type_ == pd.Series + for name, (type_, dep_type) in node_.input_types.items() + if name not in bound_parameters and dep_type == node.DependencyType.REQUIRED } - return pandas_annotation -def _bind_parameters_to_callable( +def _determine_parameters_to_bind( actual_kwargs: dict, df_columns: Set[str], - hamilton_udf: Callable, node_input_types: Dict[str, Tuple], node_name: str, -) -> Tuple[Callable, Dict[str, Any]]: +) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Function that we use to bind inputs to the function, or determine we should pull them from the dataframe. It does two things: @@ -259,24 +261,22 @@ def _bind_parameters_to_callable( :param hamilton_udf: the callable to bind to. :param node_input_types: the input types of the function. :param node_name: name of the node/function. - :return: a tuple of the callable and the dictionary of parameters to use for the callable. + :return: a tuple of the params that come from the dataframe and the parameters to bind. """ params_from_df = {} - for input_name in node_input_types.keys(): + bind_parameters = {} + for input_name, (type_, dep_type) in node_input_types.items(): if input_name in df_columns: params_from_df[input_name] = column(input_name) elif input_name in actual_kwargs and not isinstance(actual_kwargs[input_name], DataFrame): - hamilton_udf = functools.partial( - hamilton_udf, **{input_name: actual_kwargs[input_name]} - ) - elif node_input_types[input_name][1] == DependencyType.OPTIONAL: - pass - else: + bind_parameters[input_name] = actual_kwargs[input_name] + elif dep_type == node.DependencyType.REQUIRED: raise ValueError( - f"Cannot satisfy {node_name} with input types {node_input_types} against a dataframe with " + f"Cannot satisfy {node_name} with input types {node_input_types} against a " + f"dataframe with " f"columns {df_columns} and input kwargs {actual_kwargs}." 
) - return hamilton_udf, params_from_df + return params_from_df, bind_parameters def _inspect_kwargs(kwargs: Dict[str, Any]) -> Tuple[DataFrame, Dict[str, Any]]: @@ -296,9 +296,77 @@ def _inspect_kwargs(kwargs: Dict[str, Any]) -> Tuple[DataFrame, Dict[str, Any]]: return df, actual_kwargs -def _lambda_udf( - df: DataFrame, node_: node.Node, hamilton_udf: Callable, actual_kwargs: Dict[str, Any] -) -> DataFrame: +def _format_pandas_udf(func_name: str, ordered_params: List[str]) -> str: + formatting_params = { + "name": func_name, + "return_type": "pd.Series", + "params": ", ".join([f"{param}: pd.Series" for param in ordered_params]), + "param_call": ", ".join([f"{param}={param}" for param in ordered_params]), + } + func_string = """ +def {name}({params}) -> {return_type}: + return partial_fn({param_call}) +""".format( + **formatting_params + ) + return func_string + + +def _format_udf(func_name: str, ordered_params: List[str]) -> str: + formatting_params = { + "name": func_name, + "params": ", ".join(ordered_params), + "param_call": ", ".join([f"{param}={param}" for param in ordered_params]), + } + func_string = """ +def {name}({params}): + return partial_fn({param_call}) +""".format( + **formatting_params + ) + return func_string + + +def _fabricate_spark_function( + node_: node.Node, + params_to_bind: Dict[str, Any], + params_from_df: Dict[str, Any], + pandas_udf: bool, +) -> FunctionType: + """Fabricates a spark compatible UDF. We have to do this as we don't actually have a funtion + with annotations to use, as its lambdas passed around by decorators. We may consider pushing + this upstreams so that everything can generate its own function, but for now this is the + easiest way to do it. + + The rules are different for pandas series and regular UDFs. + Pandas series have to: + - be Decorated with pandas_udf + - Have a return type of a pandas series + - Have a pandas series as the only input types + Regular UDFs have to: + - Have no annotations at all + + See https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.udf.html + and https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.pandas_udf.html + + :param node_: Node to place in a spark function + :param params_to_bind: Parameters to bind to the function -- these won't go into the UDF + :param params_from_df: Parameters to retrieve from the dataframe + :return: A function that can be used in a spark UDF + """ + partial_fn = functools.partial(node_.callable, **params_to_bind) + ordered_params = sorted(params_from_df) + func_name = node_.name.replace(".", "_") + if pandas_udf: + func_string = _format_pandas_udf(func_name, ordered_params) + else: + func_string = _format_udf(func_name, ordered_params) + module_code = compile(func_string, "", "exec") + func_code = [c for c in module_code.co_consts if isinstance(c, CodeType)][0] + return FunctionType(func_code, {**globals(), **{"partial_fn": partial_fn}}, func_name) + + +def _lambda_udf(df: DataFrame, node_: node.Node, actual_kwargs: Dict[str, Any]) -> DataFrame: """Function to create a lambda UDF for a function. This functions does the following: @@ -314,15 +382,16 @@ def _lambda_udf( :param actual_kwargs: the actual arguments to the function. :return: the dataframe with one more column representing the result of the UDF. 
""" - hamilton_udf, params_from_df = _bind_parameters_to_callable( - actual_kwargs, set(df.columns), hamilton_udf, node_.input_types, node_.name + params_from_df, params_to_bind = _determine_parameters_to_bind( + actual_kwargs, set(df.columns), node_.input_types, node_.name ) - pandas_annotation = _get_pandas_annotations(hamilton_udf) + pandas_annotation = _get_pandas_annotations(node_, params_to_bind) if any(pandas_annotation.values()) and not all(pandas_annotation.values()): raise ValueError( f"Currently unsupported function for {node_.name} with function signature:\n{node_.input_types}." ) elif all(pandas_annotation.values()): + hamilton_udf = _fabricate_spark_function(node_, params_to_bind, params_from_df, True) # pull from annotation here instead of tag. base_type, type_args = htypes.get_type_information(node_.type) logger.debug("PandasUDF: %s, %s, %s", node_.name, base_type, type_args) @@ -335,17 +404,18 @@ def _lambda_udf( if isinstance(type_arg, str): spark_return_type = type_arg # spark will handle converting it. else: - spark_return_type = get_spark_type(actual_kwargs, df, hamilton_udf, type_arg) - # remove because pyspark does not like extra function annotations - hamilton_udf.__annotations__["return"] = base_type + spark_return_type = get_spark_type(type_arg) spark_udf = pandas_udf(hamilton_udf, spark_return_type) else: + hamilton_udf = _fabricate_spark_function(node_, params_to_bind, params_from_df, False) logger.debug("RegularUDF: %s, %s", node_.name, node_.type) - spark_return_type = get_spark_type(actual_kwargs, df, hamilton_udf, node_.type) + spark_return_type = get_spark_type(node_.type) spark_udf = udf(hamilton_udf, spark_return_type) - return df.withColumn( - node_.name, spark_udf(*[_value for _name, _value in params_from_df.items()]) + out = df.withColumn( + node_.name, + spark_udf(*[_value for _name, _value in sorted(params_from_df.items())]), ) + return out class PySparkUDFGraphAdapter(base.SimplePythonDataFrameGraphAdapter): @@ -412,7 +482,7 @@ def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: logger.debug("%s, %s", self.call_count, self.df_object) logger.debug("%s, Before, %s", node.name, self.df_object.columns) schema_length = len(df.schema) - df = _lambda_udf(self.df_object, node, node.callable, actual_kwargs) + df = _lambda_udf(self.df_object, node, actual_kwargs) assert node.name in df.columns, f"Error {node.name} not in {df.columns}" delta = len(df.schema) - schema_length if delta == 0: @@ -447,3 +517,742 @@ def build_result(self, **outputs: Dict[str, Any]) -> DataFrame: self.df_object = None self.original_schema = [] return result + + +def sparkify_node_with_udf( + node_: node.Node, + linear_df_dependency_name: str, + base_df_dependency_name: str, + base_df_dependency_param: Optional[str], + dependent_columns_in_group: Set[str], + dependent_columns_from_dataframe: Set[str], +) -> node.Node: + """ """ + """Turns a node into a spark node. This does the following: + 1. Makes it take the prior dataframe output as a dependency, in + conjunction to its current dependencies. This is so we can represent + the "logical" plan (the UDF-dependencies) as well as + the "physical plan" (linear, df operations) + 2. Adjusts the function to apply the specified UDF on the + dataframe, ignoring all inputs in column_dependencies + (which are only there to demonstrate lineage/make the DAG representative) + 3. 
Returns the resulting pyspark dataframe for downstream functions to use + + + :param node_: Node we're sparkifying + :param linear_df_dependency_name: Name of the linearly passed along dataframe dependency + :param base_df_dependency_name: Name of the base (parent) dataframe dependency. + this is only used if dependent_columns_from_dataframe is not empty + :param base_df_dendency_param: Name of the base (parent) dataframe dependency parameter, as known + by the node. This is only used if `pass_dataframe_as` is provided, which means that + dependent_columns_from_dataframe is empty. + :param dependent_columns_in_group: Columns on which this depends in the with_columns + :param dependent_columns_from_dataframe: Columns on which this depends in the + base (parent) dataframe that the with_columns is operating on + :return: + + """ + + def new_callable( + __linear_df_dependency_name: str = linear_df_dependency_name, + __base_df_dependency_name: str = base_df_dependency_name, + __dependent_columns_in_group: Set[str] = dependent_columns_in_group, + __dependent_columns_from_dataframe: Set[str] = dependent_columns_from_dataframe, + __base_df_dependency_param: str = base_df_dependency_param, + __node: node.Node = node_, + **kwargs, + ) -> ps.DataFrame: + """This is the new function that the node will call. + Note that this applies the hamilton UDF with *just* the input dataframe dependency, + ignoring the rest.""" + # gather the dataframe from the kwargs + df = kwargs[__linear_df_dependency_name] + kwargs = { + k: v + for k, v in kwargs.items() + if k not in __dependent_columns_from_dataframe + and k not in __dependent_columns_in_group + and k != __linear_df_dependency_name + and k != __base_df_dependency_name + } + return _lambda_udf(df, node_, kwargs) + + # Just extract the dependeency type + # TODO -- add something as a "logical" or "placeholder" dependency + new_input_types = { + # copy over the old ones + **{ + dep: value + for dep, value in node_.input_types.items() + if dep not in dependent_columns_from_dataframe + }, + # add the new one (from the previous) + linear_df_dependency_name: (DataFrame, node.DependencyType.REQUIRED), + # Then add all the others + # Note this might clobber the linear_df_dependency_name, but they'll be the same type + # If we have "logical" dependencies we'll want to be careful about the type + **{ + dep: (DataFrame, node.DependencyType.REQUIRED) + for dep, _ in node_.input_types.items() + if dep in dependent_columns_in_group + }, + } + + if base_df_dependency_param is not None and base_df_dependency_name in node_.input_types: + # In this case we want to add a dependency for visualization/lineage + new_input_types[base_df_dependency_name] = ( + DataFrame, + node.DependencyType.REQUIRED, + ) + if len(dependent_columns_from_dataframe) > 0: + new_input_types[base_df_dependency_name] = ( + DataFrame, + node.DependencyType.REQUIRED, + ) + return node_.copy_with(callabl=new_callable, input_types=new_input_types, typ=DataFrame) + + +def derive_dataframe_parameter( + param_types: Dict[str, Type], requested_parameter: str, location_name: Callable +) -> str: + dataframe_parameters = { + param for param, val in param_types.items() if custom_subclass_check(val, DataFrame) + } + if requested_parameter is not None: + if requested_parameter not in dataframe_parameters: + raise ValueError( + f"Requested parameter {requested_parameter} not found in " f"{location_name}" + ) + return requested_parameter + if len(dataframe_parameters) == 0: + raise ValueError( + f"No dataframe parameters 
found in: {location_name}. " + f"Received parameters: {param_types}. " + f"@with_columns must inject a dataframe parameter into the function." + ) + elif len(dataframe_parameters) > 1: + raise ValueError( + f"More than one dataframe parameter found in function: {location_name}. Please " + f"specify the desired one with the 'dataframe' parameter in @with_columns" + ) + assert len(dataframe_parameters) == 1 + return list(dataframe_parameters)[0] + + +def derive_dataframe_parameter_from_fn(fn: Callable, requested_parameter: str = None) -> str: + """Utility function to grab a pyspark dataframe parameter from a function. + Note if one is supplied it'll look for that. If none is, it will look to ensure + that there is only one dataframe parameter in the function. + + :param fn: Function to grab the dataframe parameter from + :param requested_parameter: If supplied, the name of the parameter to grab + :return: The name of the dataframe parameter + :raises ValueError: If no datframe parameter is supplied: + - if no dataframe parameter is found, or if more than one is found + if a requested parameter is supplied: + - if the requested parameter is not found + """ + sig = inspect.signature(fn) + parameters_with_types = {param.name: param.annotation for param in sig.parameters.values()} + return derive_dataframe_parameter(parameters_with_types, requested_parameter, fn.__qualname__) + + +def _derive_first_dataframe_parameter_from_fn(fn: Callable) -> str: + """Utility function to derive the first parameter from a function and assert + that it is annotated with a pyspark dataframe. + + :param fn: + :return: + """ + sig = inspect.signature(fn) + params = list(sig.parameters.items()) + if len(params) == 0: + raise ValueError( + f"Function {fn.__qualname__} has no parameters, but was " + f"decorated with with_columns. with_columns requires the first " + f"parameter to be a dataframe so we know how to wire dependencies." + ) + first_param_name, first_param_value = params[0] + if not custom_subclass_check(first_param_value.annotation, DataFrame): + raise ValueError( + f"Function {fn.__qualname__} has a first parameter {first_param_name} " + f"that is not a pyspark dataframe. Instead got: {first_param_value.annotation}." + f"with_columns requires the first " + f"parameter to be a dataframe so we know how to wire dependencies." + ) + return first_param_name + + +def derive_dataframe_parameter_from_node(node_: node.Node, requested_parameter: str = None) -> str: + """Derives the only/requested dataframe parameter from a node. + + :param node_: + :param requested_parameter: + :return: + """ + types_ = {key: value[0] for key, value in node_.input_types.items()} + originating_function_name = ( + node_.originating_functions[-1] if node_.originating_functions is not None else node_.name + ) + return derive_dataframe_parameter(types_, requested_parameter, originating_function_name) + + +def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> List[node.Node]: + """Prunes the nodes to only include those upstream from the select columns. + Conducts a depth-first search using the nodes `input_types` field. + + If select is None, we just assume all nodes should be included. 
+
+    :param nodes: Full set of nodes
+    :param select: Columns to select
+    :return: Pruned set of nodes
+    """
+    if select is None:
+        return nodes
+
+    node_name_map = {node_.name: node_ for node_ in nodes}
+    seen_nodes = set(select)
+    stack = list({node_name_map[col] for col in select if col in node_name_map})
+    output = []
+    while len(stack) > 0:
+        node_ = stack.pop()
+        output.append(node_)
+        for dep in node_.input_types:
+            if dep not in seen_nodes and dep in node_name_map:
+                dep_node = node_name_map[dep]
+                stack.append(dep_node)
+                seen_nodes.add(dep)
+    return output
+
+
+class require_columns(fm_base.NodeTransformer):
+    """Decorator for spark that allows for the specification of the columns a node requires.
+    This enables the user to make use of pyspark transformations on those columns inside a
+    `with_columns` group. Note that this will have no impact if it is not
+    decorating a node inside `with_columns`.
+
+    Note that this currently does not work with other decorators, but it definitely could.
+    """
+
+    TRANSFORM_TARGET_TAG = "hamilton.spark.target"
+    TRANSFORM_COLUMNS_TAG = "hamilton.spark.columns"
+
+    def __init__(self, *columns: str):
+        super(require_columns, self).__init__(target=None)
+        self._columns = columns
+
+    def transform_node(
+        self, node_: node.Node, config: Dict[str, Any], fn: Callable
+    ) -> Collection[node.Node]:
+        """Generates nodes for the `@require_columns` decorator.
+
+        This does two things, but does not fully prepare the node:
+        1. It adds the columns as dependencies to the node
+        2. Adds tags with relevant metadata for later use
+
+        Note that, at this point, we don't actually know which columns will come from the
+        base dataframe, and which will come from the upstream nodes. This is handled in the
+        `with_columns` decorator, so for now, we need to give it enough information to topologically
+        sort/assign dependencies.
+
+        :param node_: Node to transform
+        :param config: Configuration to use (unused here)
+        :return: The transformed node
+        """
+        param = derive_dataframe_parameter_from_node(node_)
+
+        # This allows for injection of any extra parameters
+        def new_callable(__input_types=node_.input_types, **kwargs):
+            return node_.callable(
+                **{key: value for key, value in kwargs.items() if key in __input_types}
+            )
+
+        additional_input_types = {
+            param: (DataFrame, node.DependencyType.REQUIRED)
+            for param in self._columns
+            if param not in node_.input_types
+        }
+        node_out = node_.copy_with(
+            input_types={**node_.input_types, **additional_input_types},
+            callabl=new_callable,
+            tags={
+                require_columns.TRANSFORM_TARGET_TAG: param,
+                require_columns.TRANSFORM_COLUMNS_TAG: self._columns,
+            },
+        )
+        # if it returns a column, we just turn it into a withColumn expression
+        if custom_subclass_check(node_.type, Column):
+
+            def transform_output(output: Column, kwargs: Dict[str, Any]) -> DataFrame:
+                return kwargs[param].withColumn(node_.name, output)
+
+            node_out = node_out.transform_output(transform_output, DataFrame)
+        return [node_out]
+
+    def validate(self, fn: Callable):
+        """Validates against the decorated function, even though this decorator operates on nodes.
+        We can always loosen this, but for now it should help the code stay readable.
+
+        :param fn: Function this is decorating
+        :return: Nothing -- raises if invalid
+        """
+
+        _derive_first_dataframe_parameter_from_fn(fn)
+
+    @staticmethod
+    def _extract_dataframe_params(node_: node.Node) -> List[str]:
+        """Extracts the dataframe parameters from a node.
+
+        :param node_: Node to extract from
+        :return: List of dataframe parameters
+        """
+        return [
+            key
+            for key, value in node_.input_types.items()
+            if custom_subclass_check(value[0], DataFrame)
+        ]
+
+    @staticmethod
+    def is_default_pyspark_udf(node_: node.Node) -> bool:
+        """Tells if a node is, by default, a pyspark UDF. This means:
+        1. It has a single dataframe parameter
+        2. That parameter name determines an upstream column name
+
+        :param node_: Node to check
+        :return: True if it functions as a default pyspark UDF, false otherwise
+        """
+        df_columns = require_columns._extract_dataframe_params(node_)
+        return len(df_columns) == 1
+
+    @staticmethod
+    def is_decorated_pyspark_udf(node_: node.Node):
+        """Tells if this is a decorated pyspark UDF. This means it has been
+        decorated by the `@require_columns` decorator.
+
+        :return: True if it can be run as part of a group, false otherwise
+        """
+        if "hamilton.spark.columns" in node_.tags and "hamilton.spark.target" in node_.tags:
+            return True
+        return False
+
+    @staticmethod
+    def sparkify_node(
+        node_: node.Node,
+        linear_df_dependency_name: str,
+        base_df_dependency_name: str,
+        base_df_param_name: Optional[str],
+        dependent_columns_from_upstream: Set[str],
+        dependent_columns_from_dataframe: Set[str],
+    ) -> node.Node:
+        """Transforms a pyspark node into a node that can be run as part of a `with_columns` group.
+        This is only for non-UDF nodes that have already been transformed by `@require_columns`.
+
+        :param node_: Node to transform
+        :param linear_df_dependency_name: Dependency on the continually modified dataframe
+            (this enables us to chain the linear sequence of dataframe operations)
+        :param base_df_dependency_name: Name of the base (parent) dataframe dependency
+        :param base_df_param_name: Name of the base dataframe parameter, as known by the node
+        :param dependent_columns_from_upstream: Columns on which this depends in the with_columns group
+        :param dependent_columns_from_dataframe: Columns on which this depends in the
+            base (parent) dataframe that the with_columns is operating on
+        :return: The final node with correct dependencies
+        """
+        transformation_target = node_.tags.get(require_columns.TRANSFORM_TARGET_TAG)
+
+        # Note that the following does not use the reassign_columns function as we have
+        # special knowledge of the function -- E.G. that it doesn't need all the parameters
+        # we choose to pass it. Thus we can just make sure that we pass it the right one,
+        # and not worry about value-clashes in reassigning names (as there are all sorts of
+        # edge cases around the parameter name to be transformed).
+
+        # We have only a few dependencies we truly need
+        # These are the linear_df_dependency_name (the dataframe that is being modified)
+        # as well as any non-dataframe arguments (E.G.
the ones that aren't about to be added + # Note that the node comes with logical dependencies already, so we filter them out + def new_callable(__callable=node_.callable, **kwargs) -> Any: + new_kwargs = kwargs.copy() + new_kwargs[transformation_target] = kwargs[linear_df_dependency_name] + return __callable(**new_kwargs) + + # We start off with everything except the transformation target, as we're + # going to use the linear dependency for that (see the callable above) + new_input_types = { + key: value + for key, value in node_.input_types.items() + if key != transformation_target and key not in dependent_columns_from_dataframe + } + # Thus we put that linear dependency in + new_input_types[linear_df_dependency_name] = (DataFrame, node.DependencyType.REQUIRED) + # Then we go through all "logical" dependencies -- columns we want to add to make lineage + # look nice + for item in dependent_columns_from_upstream: + new_input_types[item] = (DataFrame, node.DependencyType.REQUIRED) + + # Then we see if we're trying to transform the base dataframe + # This means we're not referring to it as a column, and only happens with the + # `pass_dataframe_as` argument (which means the base_df_param_name is not None) + if transformation_target == base_df_param_name: + new_input_types[base_df_dependency_name] = ( + DataFrame, + node.DependencyType.REQUIRED, + ) + # Finally we create the new node and return it + node_ = node_.copy_with(callabl=new_callable, input_types=new_input_types) + return node_ + + +def _identify_upstream_dataframe_nodes(nodes: List[node.Node]) -> List[str]: + """Gives the upstream dataframe name. This is the only ps.DataFrame parameter not + produced from within the subdag. + + :param nodes: Nodes in the subdag + :return: The name of the upstream dataframe + """ + node_names = {node_.name for node_ in nodes} + df_deps = set() + + for node_ in nodes: + # In this case its a df node that is a linear dependency, so we don't count it + # Instead we count the columns it wants, as we have not yet created them TODO -- + # consider moving this validation afterwards so we don't have to do this check + df_dependencies = node_.tags.get( + require_columns.TRANSFORM_COLUMNS_TAG, + [ + dep + for dep, (type_, _) in node_.input_types.items() + if custom_subclass_check(type_, DataFrame) + ], + ) + for dependency in df_dependencies: + if dependency not in node_names: + df_deps.add(dependency) + return list(df_deps) + + +class with_columns(fm_base.NodeCreator): + def __init__( + self, + *load_from: Union[Callable, ModuleType], + columns_to_pass: List[str] = None, + pass_dataframe_as: str = None, + select: List[str] = None, + namespace: str = None, + mode: str = "append", + ): + """Initializes a with_columns decorator for spark. This allows you to efficiently run + groups of map operations on a dataframe, represented as pandas/primitives UDFs. This + effectively "linearizes" compute -- meaning that a DAG of map operations can be run + as a set of .withColumn operations on a single dataframe -- ensuring that you don't have + to do a complex `extract` then `join` process on spark, which can be inefficient. + + Here's an example of calling it -- if you've seen `@subdag`, you should be familiar with + the concepts: + + .. 
code-block:: python + # my_module.py + def a(a_from_df: pd.Series) -> pd.Series: + return _process(a) + + def b(b_from_df: pd.Series) -> pd.Series: + return _process(b) + + def a_plus_b(a_from_df: pd.Series, b_from_df: pd.Series) -> pd.Series: + return a + b + + + # the with_columns call + @with_columns( + load_from=[my_module], # Load from any module + columns_to_pass=["a_from_df", "b_from_df"], # The columns to pass from the dataframe to + # the subdag + select=["a", "b", "a_plus_b"], # The columns to select from the dataframe + ) + def final_df(initial_df: ps.DataFrame) -> ps.DataFrame: + # process, or just return unprocessed + ... + + You can think of the above as a series of withColumn calls on the dataframe, where the + operations are applied in topological order. This is significantly more efficient than + extracting out the columns, applying the maps, then joining, but *also* allows you to + express the operations individually, making it easy to unit-test and reuse. + + Note that the operation is "append", meaning that the columns that are selected are appended + onto the dataframe. We will likely add an option to have this be either "select" or "append" + mode. + + If the function takes multiple dataframes, the dataframe input to process will always be + the first one. This will be passed to the subdag, transformed, and passed back to the functions. + This follows the hamilton rule of reference by parameter name. To demonstarte this, in the code + above, the dataframe that is passed to the subdag is `initial_df`. That is transformed + by the subdag, and then returned as the final dataframe. + + You can read it as: + + "final_df is a function that transforms the upstream dataframe initial_df, running the transformations + from my_module. It starts with the columns a_from_df and b_from_df, and then adds the columns + a, b, and a_plus_b to the dataframe. It then returns the dataframe, and does some processing on it." + + + :param load_from: The functions that will be used to generate the group of map operations. + :param select: Columns to select from the transformation. If this is left blank it will + keep all columns in the subdag. + :param columns_to_pass: The initial schema of the dataframe. This is used to determine which + upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is + left empty (and external_inputs is as well), we will assume that all dependencies come + from the dataframe. This cannot be used in conjunction with pass_dataframe_as. + :param pass_dataframe_as: The name of the dataframe that we're modifying, as known to the subdag. + If you pass this in, you are responsible for extracting columns out. If not provided, you have + to pass columns_to_pass in, and we will extract the columns out for you. + :param namespace: The namespace of the nodes, so they don't clash with the global namespace + and so this can be reused. If its left out, there will be no namespace (in which case you'll want + to be careful about repeating it/reusing the nodes in other parts of the DAG.) + :param mode: The mode of the operation. This can be either "append" or "select". + If it is "append", it will keep all columns in the dataframe. If it is "select", + it will only keep the columns in the dataframe from the `select` parameter. Note that, + if the `select` parameter is left blank, it will keep all columns in the dataframe + that are in the subdag (as that is the behavior of the `select` parameter. 
This
+            defaults to `append`.
+        """
+        self.subdag_functions = subdag.collect_functions(load_from)
+        self.select = select
+        self.initial_schema = columns_to_pass
+        if (pass_dataframe_as is not None and columns_to_pass is not None) or (
+            pass_dataframe_as is None and columns_to_pass is None
+        ):
+            raise ValueError(
+                "You must specify only one of columns_to_pass and "
+                "pass_dataframe_as. "
+                "This is because specifying pass_dataframe_as injects into "
+                "the set of columns, allowing you to perform your own extraction "
+                "from the dataframe. We then execute all columns in the subdag "
+                "in order, passing in that initial dataframe. If you want "
+                "to reference columns in your code, you'll have to specify "
+                "the set of initial columns, and allow the subdag decorator "
+                "to inject the dataframe through. The initial columns tell "
+                "us which parameters to take from that dataframe, so we can "
+                "feed the right data into the right columns."
+            )
+        self.dataframe_subdag_param = pass_dataframe_as
+        self.namespace = namespace
+        self.upstream_dependency = dataframe
+        self.mode = mode
+
+    @staticmethod
+    def _prep_nodes(initial_nodes: List[node.Node]) -> List[node.Node]:
+        """Prepares nodes by decorating "default" UDFs with `require_columns`.
+        This allows us to use the sparkify_node function in `require_columns`
+        for both the default ones and the decorated ones.
+
+        :param initial_nodes: Initial nodes to prepare
+        :return: Prepared nodes
+        """
+        out = []
+        for node_ in initial_nodes:
+            if require_columns.is_default_pyspark_udf(node_):
+                col = derive_dataframe_parameter_from_node(node_)
+                # TODO -- wire through config/function correctly
+                # the col is the only dataframe parameter, so it is the target node
+                (node_,) = require_columns(col).transform_node(node_, {}, node_.callable)
+            out.append(node_)
+        return out
+
+    @staticmethod
+    def create_selector_node(
+        upstream_name: str, columns: List[str], node_name: str = "select"
+    ) -> node.Node:
+        """Creates a selector node. The sole job of this is to select just the specified columns.
+        Note this is a utility function that's only called internally, when the
+        `with_columns` mode is "select".
+
+        :param upstream_name: Name of the upstream dataframe node
+        :param columns: Columns to select
+        :param node_name: Name of the node to create
+        :return: The selector node
+        """
+
+        def new_callable(**kwargs) -> DataFrame:
+            return kwargs[upstream_name].select(*columns)
+
+        return node.Node(
+            name=node_name,
+            typ=DataFrame,
+            callabl=new_callable,
+            input_types={upstream_name: DataFrame},
+        )
+
+    def _validate_dataframe_subdag_parameter(self, nodes: List[node.Node], fn_name: str):
+        all_upstream_dataframe_nodes = _identify_upstream_dataframe_nodes(nodes)
+        initial_schema = set(self.initial_schema) if self.initial_schema is not None else set()
+        candidates_for_upstream_dataframe = set(all_upstream_dataframe_nodes) - set(initial_schema)
+        if (
+            len(candidates_for_upstream_dataframe) > 1
+            or self.dataframe_subdag_param is None
+            and len(candidates_for_upstream_dataframe) > 0
+        ):
+            raise ValueError(
+                f"We found multiple upstream dataframe parameters for function: {fn_name} decorated with "
+                f"@with_columns. You specified pass_dataframe_as={self.dataframe_subdag_param} as the upstream "
+                f"dataframe parameter, which means that your subdag must have exactly {0 if self.dataframe_subdag_param is None else 1} "
+                f"upstream dataframe parameters. 
Instead, we found the following upstream dataframe parameters: {candidates_for_upstream_dataframe}" + ) + if self.dataframe_subdag_param is not None: + if len(candidates_for_upstream_dataframe) == 0: + raise ValueError( + f"You specified your set of UDFs to use upstream dataframe parameter: {self.dataframe_subdag_param} " + f"for function: {fn_name} decorated with `with_columns`, but we could not find " + "that parameter as a dependency of any of the nodes. Note that that dependency " + "must be a pyspark dataframe. If you wish, instead, to supply an initial set of " + "columns for the upstream dataframe and refer to those columns directly within " + "your UDFs, please use columns_to_pass instead of pass_dataframe_as." + ) + (upstream_dependency,) = list(candidates_for_upstream_dataframe) + if upstream_dependency != self.dataframe_subdag_param: + raise ValueError( + f"You specified your set of UDFs to use upstream dataframe parameter: {self.dataframe_subdag_param} " + f"for function: {fn_name} decorated with `with_columns`, but we found that parameter " + f"as a dependency of a node, but it was not the same as the parameter you specified. " + f"Instead, we found: {upstream_dependency}." + ) + + def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]: + """Generates nodes in the with_columns groups. This does the following: + + 1. Collects all the nodes from the subdag functions + 2. Prunes them to only include the ones that are upstream from the select columns + 3. Sorts them topologically + 4. Creates a new node for each one, injecting the dataframe parameter into the first one + 5. Creates a new node for the final one, injecting the last node into that one + 6. Returns the list of nodes + + :param fn: Function to generate from + :param config: Config to use for generating/collecting nodes + :return: List of nodes that this function produces + """ + namespace = fn.__name__ if self.namespace is None else self.namespace + initial_nodes = subdag.collect_nodes(config, self.subdag_functions) + transformed_nodes = with_columns._prep_nodes(initial_nodes) + self._validate_dataframe_subdag_parameter(transformed_nodes, fn.__qualname__) + pruned_nodes = prune_nodes(transformed_nodes, self.select) + if len(pruned_nodes) == 0: + raise ValueError( + f"No nodes found upstream from select columns: {self.select} for function: " + f"{fn.__qualname__}" + ) + sorted_initial_nodes = graph_functions.topologically_sort_nodes(pruned_nodes) + output_nodes = [] + inject_parameter = _derive_first_dataframe_parameter_from_fn(fn) + current_dataframe_node = inject_parameter + # Columns that it is dependent on could be from the group of transforms created + columns_produced_within_mapgroup = {node_.name for node_ in pruned_nodes} + columns_passed_in_from_dataframe = ( + set(self.initial_schema) if self.initial_schema is not None else [] + ) + # Or from the dataframe passed in... + for node_ in sorted_initial_nodes: + # dependent columns are broken into two sets: + # 1. Those that come from the group of transforms + dependent_columns_in_mapgroup = { + column for column in node_.input_types if column in columns_produced_within_mapgroup + } + # 2. 
Those that come from the dataframe + dependent_columns_in_dataframe = { + column for column in node_.input_types if column in columns_passed_in_from_dataframe + } + # In the case that we are using pyspark UDFs + if require_columns.is_decorated_pyspark_udf(node_): + sparkified = require_columns.sparkify_node( + node_, + current_dataframe_node, + inject_parameter, + self.dataframe_subdag_param, + dependent_columns_in_mapgroup, + dependent_columns_in_dataframe, + ) + # otherwise we're using pandas/primitive UDFs + else: + sparkified = sparkify_node_with_udf( + node_, + current_dataframe_node, + inject_parameter, + self.dataframe_subdag_param, + dependent_columns_in_mapgroup, + dependent_columns_in_dataframe, + ) + output_nodes.append(sparkified) + current_dataframe_node = sparkified.name + # We get the final node, which is the function we're using + # and reassign inputs to be the dataframe + if self.mode == "select": + select_columns = ( + self.select if self.select is not None else [item.name for item in output_nodes] + ) + select_node = with_columns.create_selector_node( + upstream_name=current_dataframe_node, columns=select_columns, node_name="_select" + ) + output_nodes.append(select_node) + current_dataframe_node = select_node.name + output_nodes = subdag.add_namespace(output_nodes, namespace) + final_node = node.Node.from_fn(fn).reassign_input_names( + {inject_parameter: assign_namespace(current_dataframe_node, namespace)} + ) + return output_nodes + [final_node] + + def validate(self, fn: Callable): + _derive_first_dataframe_parameter_from_fn(fn) + + +@dataclasses.dataclass +class SparkDataFrameDataLoader(DataLoader): + """Base class for data loaders that load pyspark dataframes. + We are not yet including data savers, but that will be added to this most likely.. 
+ """ + + spark: SparkSession + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DataFrame] + + @abc.abstractmethod + def load_data(self, type_: Type[DataFrame]) -> Tuple[ps.DataFrame, Dict[str, Any]]: + pass + + +@dataclasses.dataclass +class CSVDataLoader(SparkDataFrameDataLoader): + path: str # It supports multiple but for now we're going to have a single one + # We can always make that a list of strings, or make a multiple reader (.multicsv) + header: bool = True + sep: str = "," + + def load_data(self, type_: Type[DataFrame]) -> Tuple[ps.DataFrame, Dict[str, Any]]: + return ( + self.spark.read.csv(self.path, header=self.header, sep=self.sep, inferSchema=True), + utils.get_file_metadata(self.path), + ) + + @classmethod + def name(cls) -> str: + return "csv" + + +@dataclasses.dataclass +class ParquetDataLoader(SparkDataFrameDataLoader): + path: str # It supports multiple but for now we're going to have a single one + + # We can always make that a list of strings, or make a multiple reader (.multicsv) + + def load_data(self, type_: Type[DataFrame]) -> Tuple[ps.DataFrame, Dict[str, Any]]: + return self.spark.read.parquet(self.path), utils.get_file_metadata(self.path) + + @classmethod + def name(cls) -> str: + return "parquet" + + +def register_data_loaders(): + """Function to register the data loaders for this extension.""" + for loader in [CSVDataLoader, ParquetDataLoader]: + registry.register_adapter(loader) + + +register_data_loaders() diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index c5cbffef3..1d54be93b 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -248,12 +248,11 @@ def collect_functions( ) return out - def _collect_nodes(self, original_config: Dict[str, Any]): - combined_config = dict(original_config, **self.config) + @staticmethod + def collect_nodes(config: Dict[str, Any], subdag_functions: List[Callable]) -> List[node.Node]: nodes = [] - for fn in self.subdag_functions: - for node_ in base.resolve_nodes(fn, combined_config): - # nodes.append(node_) + for fn in subdag_functions: + for node_ in base.resolve_nodes(fn, config): nodes.append(node_.copy_with(tags={**node_.tags, **NON_FINAL_TAGS})) return nodes @@ -302,25 +301,32 @@ def _create_additional_static_nodes( ) return out - def _add_namespace(self, nodes: List[node.Node], namespace: str) -> List[node.Node]: + @staticmethod + def add_namespace( + nodes: List[node.Node], + namespace: str, + inputs: Dict[str, Any] = None, + config: Dict[str, Any] = None, + ) -> List[node.Node]: """Utility function to add a namespace to nodes. :param nodes: :return: """ - # already_namespaced_nodes = [] + inputs = inputs if inputs is not None else {} + config = config if config is not None else {} new_nodes = [] new_name_map = {} # First pass we validate + collect names so we can alter dependencies for node_ in nodes: new_name = assign_namespace(node_.name, namespace) new_name_map[node_.name] = new_name - for dep, value in self.inputs.items(): + for dep, value in inputs.items(): # We create nodes for both namespace assignment and source assignment # Why? 
Cause we need unique parameter names, and with source() some can share params new_name_map[dep] = assign_namespace(dep, namespace) - for dep, value in self.config.items(): + for dep, value in config.items(): new_name_map[dep] = assign_namespace(dep, namespace) # Reassign sources @@ -398,12 +404,13 @@ def _derive_name(self, fn: Callable) -> str: def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]: # Resolve all nodes from passed in functions - nodes = self._collect_nodes(original_config=configuration) + resolved_config = dict(configuration, **self.config) + nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions) # Derive the namespace under which all these nodes will live namespace = self._derive_namespace(fn) final_node_name = self._derive_name(fn) # Rename them all to have the right namespace - nodes = self._add_namespace(nodes, namespace) + nodes = self.add_namespace(nodes, namespace, self.inputs, self.config) # Create any static input nodes we need to translate nodes += self._create_additional_static_nodes(nodes, namespace) # Add the final node that does the translation diff --git a/hamilton/htypes.py b/hamilton/htypes.py index e5f70fa9a..613a29c13 100644 --- a/hamilton/htypes.py +++ b/hamilton/htypes.py @@ -45,6 +45,7 @@ def custom_subclass_check(requested_type: Type, param_type: Type): """ # handles case when someone is using primitives and generics requested_origin_type = requested_type + param_type, _ = get_type_information(param_type) param_origin_type = param_type has_generic = False if _safe_subclass(requested_type, param_type): diff --git a/hamilton/node.py b/hamilton/node.py index 3bba90aa6..6d55b4195 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -1,4 +1,5 @@ import inspect +import sys import typing from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union @@ -241,7 +242,9 @@ def from_fn(fn: Callable, name: str = None) -> "Node": """ if name is None: name = fn.__name__ - return_type = typing.get_type_hints(fn).get("return") + # TODO -- remove this when we no longer support 3.8 -- 10/14/2024 + type_hint_kwargs = {} if sys.version_info < (3, 9) else {"include_extras": True} + return_type = typing.get_type_hints(fn, **type_hint_kwargs).get("return") if return_type is None: raise ValueError(f"Missing type hint for return value in function {fn.__qualname__}.") node_source = NodeType.STANDARD @@ -255,7 +258,6 @@ def from_fn(fn: Callable, name: str = None) -> "Node": if typing_inspect.get_origin(hint) == Collect: node_source = NodeType.COLLECT break - module = inspect.getmodule(fn).__name__ return Node( name, @@ -302,3 +304,38 @@ def copy(self, include_refs: bool = True) -> "Node": :return: A copy of the node. """ return self.copy_with(include_refs) + + def reassign_input_names(self, input_names: Dict[str, Any]) -> "Node": + """Reassigns the input names of a node. Useful for applying + a node to a separate input if needed. Note that things can get a + little strange if you have multiple inputs with the same name, so + be careful about how you use this. 
+ + :param input_names: Input name map to reassign + :return: A node with the input names reassigned + """ + + def new_callable(**kwargs) -> Any: + reverse_input_names = {v: k for k, v in input_names.items()} + return self.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()}) + + new_input_types = {input_names.get(k, k): v for k, v in self.input_types.items()} + out = self.copy_with(callabl=new_callable, input_types=new_input_types) + return out + + def transform_output( + self, __transform: Callable[[Dict[str, Any], Any], Any], __output_type: Type[Any] + ) -> "Node": + """Applies a transformation on the output of the node, returning a new node. + Also modifies the type. + + :param __transform: Transformation to apply. This is a function with two arguments: + (a) the kwargs passed to the node, and (b) the output of the node. + :param __output_type: Return type of the transformation + :return: A new node, with the right type/transformation + """ + + def new_callable(**kwargs) -> Any: + return __transform(self.callable(**kwargs), kwargs) + + return self.copy_with(callabl=new_callable, typ=__output_type) diff --git a/tests/resources/spark/__init__.py b/tests/resources/spark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/resources/spark/basic_spark_dag.py b/tests/resources/spark/basic_spark_dag.py new file mode 100644 index 000000000..78b0f66fd --- /dev/null +++ b/tests/resources/spark/basic_spark_dag.py @@ -0,0 +1,78 @@ +import pandas as pd +import pyspark.sql as ps + +from hamilton.experimental import h_spark +from hamilton.function_modifiers import config +from hamilton.htypes import column as _ + +IntSeries = _[pd.Series, int] +FloatSeries = _[pd.Series, float] + + +def a(a_raw: IntSeries) -> IntSeries: + return a_raw + 1 + + +def b(b_raw: IntSeries) -> IntSeries: + return b_raw + 3 + + +def c(c_raw: IntSeries) -> FloatSeries: + return c_raw * 3.5 + + +def a_times_key(a: IntSeries, key: IntSeries) -> IntSeries: + return a * key + + +def b_times_key(b: IntSeries, key: IntSeries) -> IntSeries: + return b * key + + +def a_plus_b_plus_c(a: IntSeries, b: IntSeries, c: FloatSeries) -> FloatSeries: + return a + b + c + + +def df_1(spark_session: ps.SparkSession) -> ps.DataFrame: + df = pd.DataFrame.from_records( + [ + {"key": 1, "a_raw": 1, "b_raw": 2, "c_raw": 1}, + {"key": 2, "a_raw": 4, "b_raw": 5, "c_raw": 2}, + {"key": 3, "a_raw": 7, "b_raw": 8, "c_raw": 3}, + {"key": 4, "a_raw": 10, "b_raw": 11, "c_raw": 4}, + {"key": 5, "a_raw": 13, "b_raw": 14, "c_raw": 5}, + ] + ) + return spark_session.createDataFrame(df) + + +@h_spark.with_columns( + a, + b, + c, + a_times_key, + b_times_key, + a_plus_b_plus_c, + select=["a_times_key", "b_times_key", "a_plus_b_plus_c"], + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], +) +@config.when_not(mode="select") +def processed_df_as_pandas__append(df_1: ps.DataFrame) -> pd.DataFrame: + return df_1.select("a_times_key", "b_times_key", "a_plus_b_plus_c").toPandas() + + +@h_spark.with_columns( + a, + b, + c, + a_times_key, + b_times_key, + a_plus_b_plus_c, + select=["a_times_key", "a_plus_b_plus_c"], + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], + mode="select", +) +@config.when(mode="select") +def processed_df_as_pandas__select(df_1: ps.DataFrame) -> pd.DataFrame: + # This should have two columns + return df_1.toPandas() diff --git a/tests/resources/pyspark_udfs.py b/tests/resources/spark/pyspark_udfs.py similarity index 100% rename from tests/resources/pyspark_udfs.py rename to 
tests/resources/spark/pyspark_udfs.py diff --git a/tests/resources/spark/spark_dag_external_dependencies.py b/tests/resources/spark/spark_dag_external_dependencies.py new file mode 100644 index 000000000..61b16473f --- /dev/null +++ b/tests/resources/spark/spark_dag_external_dependencies.py @@ -0,0 +1,40 @@ +import pandas as pd +import pyspark.sql as ps + +from hamilton.experimental import h_spark +from hamilton.htypes import column as _ + +IntSeries = _[pd.Series, int] + + +def to_multiply() -> int: + return 2 + + +def a(initial_column: IntSeries, to_add: int = 1) -> IntSeries: + return initial_column + to_add + + +def b(a: IntSeries, to_multiply: int) -> IntSeries: + return a * to_multiply + + +def df_input(spark_session: ps.SparkSession) -> ps.DataFrame: + df = pd.DataFrame.from_records( + [ + {"initial_column": 1}, + {"initial_column": 2}, + {"initial_column": 3}, + {"initial_column": 4}, + ] + ) + return spark_session.createDataFrame(df) + + +@h_spark.with_columns( + a, + b, + columns_to_pass=["initial_column"], +) +def processed_df_as_pandas(df_input: ps.DataFrame) -> pd.DataFrame: + return df_input.select("a", "b").toPandas() diff --git a/tests/resources/spark/spark_dag_mixed_pyspark_pandas_udfs.py b/tests/resources/spark/spark_dag_mixed_pyspark_pandas_udfs.py new file mode 100644 index 000000000..e5809c9e4 --- /dev/null +++ b/tests/resources/spark/spark_dag_mixed_pyspark_pandas_udfs.py @@ -0,0 +1,95 @@ +from typing import Callable, List + +import pandas as pd +import pyspark.sql as ps + +from hamilton.experimental import h_spark +from hamilton.function_modifiers import parameterize, value +from hamilton.htypes import column as _ + +IntSeries = _[pd.Series, int] +FloatSeries = _[pd.Series, float] + + +def to_add() -> int: + return 1 + + +def spark_session() -> ps.SparkSession: + spark = ( + ps.SparkSession.builder.master("local") + .appName("spark session") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + ) + return spark + + +def _module(user_controls_initial_dataframe: bool) -> List[Callable]: + out = [] + if user_controls_initial_dataframe: + + @parameterize( + a_raw={"col": value("a_raw")}, + b_raw={"col": value("b_raw")}, + c_raw={"col": value("c_raw")}, + key={"col": value("key")}, + ) + def raw_col(external_dataframe: ps.DataFrame, col: str) -> ps.Column: + return external_dataframe[col] + + out.append(raw_col) + + def a(a_raw: ps.DataFrame, to_add: int) -> ps.DataFrame: + return a_raw.withColumn("a", a_raw.a_raw + to_add) + + def b(b_raw: ps.DataFrame, b_add: int = 3) -> ps.Column: + return b_raw["b_raw"] + b_add + + def c(c_raw: IntSeries) -> FloatSeries: + return c_raw * 3.5 + + @h_spark.require_columns("a", "key") + def a_times_key(a_key: ps.DataFrame, identity_multiplier: int = 1) -> ps.Column: + return a_key.a * a_key.key * identity_multiplier + + def b_times_key(b: IntSeries, key: IntSeries) -> IntSeries: + return b * key + + @h_spark.require_columns("a", "b", "c") + def a_plus_b_plus_c(a_b_c: ps.DataFrame) -> ps.Column: + return a_b_c.a + a_b_c.b + a_b_c.c + + out.extend([a, b, c, a_times_key, b_times_key, a_plus_b_plus_c]) + return out + + +def df_1(spark_session: ps.SparkSession) -> ps.DataFrame: + df = pd.DataFrame.from_records( + [ + {"key": 1, "a_raw": 1, "b_raw": 2, "c_raw": 1}, + {"key": 2, "a_raw": 4, "b_raw": 5, "c_raw": 2}, + {"key": 3, "a_raw": 7, "b_raw": 8, "c_raw": 3}, + {"key": 4, "a_raw": 10, "b_raw": 11, "c_raw": 4}, + {"key": 5, "a_raw": 13, "b_raw": 14, "c_raw": 5}, + ] + ) + return spark_session.createDataFrame(df) + + 
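+# Note: the two decorated functions below exercise the two ways of wiring the source dataframe
+# into a with_columns group: `columns_to_pass` (Hamilton pulls the named columns off of df_1 and
+# feeds them to the UDFs directly), and `pass_dataframe_as` (the raw dataframe is handed to the
+# subdag under the given name, and the functions -- here `raw_col` -- extract columns themselves).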
+@h_spark.with_columns( + *_module(False), + select=["a_times_key", "b_times_key", "a_plus_b_plus_c"], + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], +) +def processed_df_as_pandas(df_1: ps.DataFrame) -> pd.DataFrame: + return df_1.select("a_times_key", "b_times_key", "a_plus_b_plus_c").toPandas() + + +@h_spark.with_columns( + *_module(True), + select=["a_times_key", "b_times_key", "a_plus_b_plus_c"], + pass_dataframe_as="external_dataframe", +) +def processed_df_as_pandas_dataframe_with_injected_dataframe(df_1: ps.DataFrame) -> pd.DataFrame: + return df_1.select("a_times_key", "b_times_key", "a_plus_b_plus_c").toPandas() diff --git a/tests/resources/spark/spark_dag_multiple_with_columns.py b/tests/resources/spark/spark_dag_multiple_with_columns.py new file mode 100644 index 000000000..8d00aefab --- /dev/null +++ b/tests/resources/spark/spark_dag_multiple_with_columns.py @@ -0,0 +1,114 @@ +import pandas as pd +import pyspark.sql as ps + +from hamilton.experimental import h_spark +from hamilton.htypes import column as _ + +IntSeries = _[pd.Series, int] +FloatSeries = _[pd.Series, float] + + +def a(a_raw: IntSeries) -> IntSeries: + return a_raw + 1 + + +def b(b_raw: IntSeries) -> IntSeries: + return b_raw + 3 + + +def c(c_raw: IntSeries) -> FloatSeries: + return c_raw * 3.5 + + +def a_times_key(a: IntSeries, key: IntSeries) -> IntSeries: + return a * key + + +def b_times_key(b: IntSeries, key: IntSeries) -> IntSeries: + return b * key + + +def a_plus_b_plus_c(a: IntSeries, b: IntSeries, c: FloatSeries) -> FloatSeries: + return a + b + c + + +def const_1() -> float: + return 4.3 + + +# Placing these functions here so we don't try to read the DAG +# This tests the mixing of different types, which is *only* allowed +# inside the with_columns subdag, and not yet allowed within Hamilton +# as hamilton doesn't know that they will compile to the same nodes + + +def _df_2_modules(): + def d(d_raw: IntSeries) -> IntSeries: + return d_raw + 5 + + def e(e_raw: int, d: int, const_1: float) -> float: + return e_raw + d + const_1 + + def f(f_raw: int) -> float: + return f_raw * 3.5 + + def multiply_d_e_f_key( + d: IntSeries, e: FloatSeries, f: FloatSeries, key: IntSeries + ) -> FloatSeries: + return d * e * f * key + + return [d, e, f, multiply_d_e_f_key] + + +def df_1(spark_session: ps.SparkSession) -> ps.DataFrame: + df = pd.DataFrame.from_records( + [ + {"key": 1, "a_raw": 1, "b_raw": 2, "c_raw": 1}, + {"key": 2, "a_raw": 4, "b_raw": 5, "c_raw": 2}, + {"key": 3, "a_raw": 7, "b_raw": 8, "c_raw": 3}, + {"key": 4, "a_raw": 10, "b_raw": 11, "c_raw": 4}, + {"key": 5, "a_raw": 13, "b_raw": 14, "c_raw": 5}, + ] + ) + return spark_session.createDataFrame(df) + + +@h_spark.with_columns( + # TODO -- have a pool (module, rather than function) that we can select *from* + # Or just select in the processed_df? 
+ a, + b, + c, + a_times_key, + b_times_key, + a_plus_b_plus_c, + select=["a_times_key", "b_times_key", "a_plus_b_plus_c"], + columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], +) +def processed_df_1(df_1: ps.DataFrame) -> ps.DataFrame: + return df_1.select("key", "a_times_key", "b_times_key", "a_plus_b_plus_c") + + +def df_2(spark_session: ps.SparkSession) -> ps.DataFrame: + df = pd.DataFrame.from_records( + [ + {"key": 1, "d_raw": 1, "e_raw": 2, "f_raw": 5}, + {"key": 2, "d_raw": 4, "e_raw": 5, "f_raw": 10}, + {"key": 3, "d_raw": 7, "e_raw": 8, "f_raw": 15}, + {"key": 4, "d_raw": 10, "e_raw": 11, "f_raw": 20}, + ] + ) + return spark_session.createDataFrame(df) + + +@h_spark.with_columns( + *_df_2_modules(), + select=["multiply_d_e_f_key", "d", "e", "f"], + columns_to_pass=["d_raw", "e_raw", "f_raw", "key"], +) +def processed_df_2_joined_df_1(df_2: ps.DataFrame, processed_df_1: ps.DataFrame) -> ps.DataFrame: + return df_2.join(processed_df_1, processed_df_1["key"] == df_2["key"], "inner").drop(df_2.key) + + +def final(processed_df_2_joined_df_1: ps.DataFrame) -> pd.DataFrame: + return processed_df_2_joined_df_1.toPandas() diff --git a/tests/resources/spark/spark_dag_pyspark_udfs.py b/tests/resources/spark/spark_dag_pyspark_udfs.py new file mode 100644 index 000000000..cfa54c560 --- /dev/null +++ b/tests/resources/spark/spark_dag_pyspark_udfs.py @@ -0,0 +1,92 @@ +from typing import Callable, List + +import pandas as pd +import pyspark.sql as ps + +from hamilton.experimental import h_spark +from hamilton.function_modifiers import parameterize, value + + +def to_add() -> int: + return 1 + + +def spark_session() -> ps.SparkSession: + spark = ( + ps.SparkSession.builder.master("local") + .appName("spark session") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + ) + return spark + + +def _module(user_controls_initial_dataframe: bool) -> List[Callable]: + out = [] + if user_controls_initial_dataframe: + + @parameterize( + a_raw={"col": value("a_raw")}, + b_raw={"col": value("b_raw")}, + c_raw={"col": value("c_raw")}, + key={"col": value("key")}, + ) + def raw_col(external_dataframe: ps.DataFrame, col: str) -> ps.Column: + return external_dataframe[col] + + out.append(raw_col) + + def a(a_raw: ps.DataFrame, to_add: int) -> ps.DataFrame: + return a_raw.withColumn("a", a_raw.a_raw + to_add) + + def b(b_raw: ps.DataFrame, b_add: int = 3) -> ps.Column: + return b_raw["b_raw"] + b_add + + def c(c_raw: ps.DataFrame) -> ps.Column: + return c_raw.c_raw * 3.5 + + @h_spark.require_columns("a", "key") + def a_times_key(a_key: ps.DataFrame, identity_multiplier: int = 1) -> ps.Column: + return a_key.a * a_key.key * identity_multiplier + + @h_spark.require_columns("b", "key") + def b_times_key(b_key: ps.DataFrame) -> ps.Column: + return b_key.b * b_key.key + + @h_spark.require_columns("a", "b", "c") + def a_plus_b_plus_c(a_b_c: ps.DataFrame) -> ps.Column: + return a_b_c.a + a_b_c.b + a_b_c.c + + out.extend([a, b, c, a_times_key, b_times_key, a_plus_b_plus_c]) + return out + + +def df_1(spark_session: ps.SparkSession) -> ps.DataFrame: + df = pd.DataFrame.from_records( + [ + {"key": 1, "a_raw": 1, "b_raw": 2, "c_raw": 1}, + {"key": 2, "a_raw": 4, "b_raw": 5, "c_raw": 2}, + {"key": 3, "a_raw": 7, "b_raw": 8, "c_raw": 3}, + {"key": 4, "a_raw": 10, "b_raw": 11, "c_raw": 4}, + {"key": 5, "a_raw": 13, "b_raw": 14, "c_raw": 5}, + ] + ) + return spark_session.createDataFrame(df) + + +@h_spark.with_columns( + *_module(False), + select=["a_times_key", "b_times_key", "a_plus_b_plus_c"], + 
columns_to_pass=["a_raw", "b_raw", "c_raw", "key"], +) +def processed_df_as_pandas(df_1: ps.DataFrame) -> pd.DataFrame: + return df_1.select("a_times_key", "b_times_key", "a_plus_b_plus_c").toPandas() + + +@h_spark.with_columns( + *_module(True), + select=["a_times_key", "b_times_key", "a_plus_b_plus_c"], + pass_dataframe_as="external_dataframe", +) +def processed_df_as_pandas_with_injected_dataframe(df_1: ps.DataFrame) -> pd.DataFrame: + return df_1.select("a_times_key", "b_times_key", "a_plus_b_plus_c").toPandas() diff --git a/tests/test_type_utils.py b/tests/test_type_utils.py index 4670e9104..4e449d4a0 100644 --- a/tests/test_type_utils.py +++ b/tests/test_type_utils.py @@ -57,6 +57,8 @@ class Y(X): (typing.Dict, collections.Counter, True), # These are not subclasses of each other, see issue 42 (typing.FrozenSet[int], typing.Set[int], False), + (htypes.column[pd.Series, int], pd.Series, True), + (htypes.column[pd.Series, int], int, False), ], ) def test_custom_subclass_check(param_type, requested_type, expected):
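
For readers who want to try the new `with_columns` test resources directly, the following is a minimal usage sketch, not part of the diff itself. It assumes the repository root is on `sys.path` so that `tests.resources.spark.basic_spark_dag` is importable, picks the append variant of `processed_df_as_pandas` by setting `mode` to something other than `"select"`, and uses a dict-based result builder so the output comes back keyed by node name.

```python
# Hypothetical usage sketch (assumptions noted above) -- not part of this diff.
# Runs the with_columns DAG defined in tests/resources/spark/basic_spark_dag.py.
import pyspark.sql as ps

from hamilton import base, driver
from tests.resources.spark import basic_spark_dag  # assumes repo root on sys.path

spark = ps.SparkSession.builder.master("local[1]").appName("with_columns_demo").getOrCreate()

# mode != "select" resolves the @config.when_not(mode="select") variant,
# i.e. processed_df_as_pandas__append.
dr = driver.Driver(
    {"mode": "append"},
    basic_spark_dag,
    adapter=base.SimplePythonGraphAdapter(base.DictResult()),
)

# spark_session is a required external input -- df_1 takes it as a parameter
# and no node in the module produces it.
result = dr.execute(["processed_df_as_pandas"], inputs={"spark_session": spark})
print(result["processed_df_as_pandas"].head())

spark.stop()
```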