diff --git a/examples/LLM_Workflows/pdf_summarizer/README.md b/examples/LLM_Workflows/pdf_summarizer/README.md
index 514cd1024..2236983c5 100644
--- a/examples/LLM_Workflows/pdf_summarizer/README.md
+++ b/examples/LLM_Workflows/pdf_summarizer/README.md
@@ -35,3 +35,7 @@ or you can do `docker compose logs -f` to tail the logs (ctrl+c to stop tailing
3. Uncomment dagworks-sdk in `requirements.txt`.
4. Uncomment the lines in server.py to replace `sync_dr` with the DAGWorks Driver.
5. Rebuild the docker images `docker compose up -d --build`.
+
+# Running on Spark!
+Yes, that's right, you can also run the exact same code on Spark! It's just a one-line
+code change. See the [run_on_spark README](run_on_spark/README.md) for more details.
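+
+Roughly, the change amounts to swapping in the PySpark adapter when creating the Hamilton driver.
+A minimal sketch (see `run_on_spark/run.py` for the full, runnable version):
+
+```python
+from hamilton import driver
+from hamilton.experimental import h_spark
+
+import summarization  # the same dataflow module the backend uses
+
+adapter = h_spark.PySparkUDFGraphAdapter()
+dr = driver.Driver(dict(file_type="pdf"), summarization, adapter=adapter)
+```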
diff --git a/examples/LLM_Workflows/pdf_summarizer/backend/server.py b/examples/LLM_Workflows/pdf_summarizer/backend/server.py
index 9974b5127..4bb5d49ce 100644
--- a/examples/LLM_Workflows/pdf_summarizer/backend/server.py
+++ b/examples/LLM_Workflows/pdf_summarizer/backend/server.py
@@ -12,9 +12,7 @@
# define constants for Hamilton driver
-driver_config = dict(
- file_type="pdf",
-)
+driver_config = dict(file_type="pdf")
# instantiate the Hamilton driver; it will power all API endpoints
# async driver for use with async functions
diff --git a/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py b/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py
index 7d2c0dce2..74cff2d3b 100644
--- a/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py
+++ b/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py
@@ -1,6 +1,6 @@
import concurrent
import tempfile
-from typing import Generator
+from typing import Generator, Union
import openai
import tiktoken
@@ -26,10 +26,10 @@ def summarize_text_from_summaries_prompt(content_type: str = "an academic paper"
@config.when(file_type="pdf")
-def raw_text(pdf_source: str | bytes | tempfile.SpooledTemporaryFile) -> str:
+def raw_text(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str:
"""Takes a filepath to a PDF and returns a string of the PDF's contents
- :param pdf_source: Series of filepaths to PDFs
- :return: Series of strings of the PDFs' contents
+ :param pdf_source: the path, or the temporary file, to the PDF.
+ :return: the text of the PDF.
"""
reader = PdfReader(pdf_source)
_pdf_text = ""
@@ -67,10 +67,10 @@ def _create_chunks(text: str, n: int, tokenizer: tiktoken.Encoding) -> Generator
def chunked_text(
- raw_text: str, max_token_length: int = 1500, tokenizer_encoding: str = "cl100k_base"
+ raw_text: str, tokenizer_encoding: str = "cl100k_base", max_token_length: int = 1500
) -> list[str]:
"""Chunks the pdf text into smaller chunks of size max_token_length.
- :param pdf_text: the Series of individual pdf texts to chunk.
+ :param raw_text: the Series of individual pdf texts to chunk.
:param max_token_length: the maximum length of tokens in each chunk.
:param tokenizer_encoding: the encoding to use for the tokenizer.
:return: Series of chunked pdf text. Each element is a list of chunks.
@@ -102,7 +102,7 @@ def summarized_chunks(
"""Summarizes a series of chunks of text.
Note: this takes the first result from the top_n_related_articles series and summarizes it. This is because
the top_n_related_articles series is sorted by relatedness, so the first result is the most related.
- :param top_n_related_articles: series with each entry being a list of chunks of text for an article.
+ :param chunked_text: a list of chunks of text for an article.
:param summarize_chunk_of_text_prompt: the prompt to use to summarize each chunk of text.
:param openai_gpt_model: the openai gpt model to use.
:return: a single string of each chunk of text summarized, concatenated together.
@@ -125,12 +125,14 @@ def summarized_chunks(
def prompt_and_text_content(
- summarize_text_from_summaries_prompt: str, user_query: str, summarized_chunks: str
+ summarized_chunks: str,
+ summarize_text_from_summaries_prompt: str,
+ user_query: str,
) -> str:
"""Creates the prompt for summarizing the text from the summarized chunks of the pdf.
+ :param summarized_chunks: a long string of chunked summaries of a file.
:param summarize_text_from_summaries_prompt: the template to use to summarize the chunks.
:param user_query: the original user query.
- :param summarized_chunks: a long string of chunked summaries of a file.
:return: the prompt to use to summarize the chunks.
"""
return summarize_text_from_summaries_prompt.format(query=user_query, results=summarized_chunks)
diff --git a/examples/LLM_Workflows/pdf_summarizer/run_on_spark/README.md b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/README.md
new file mode 100644
index 000000000..67e7eaedd
--- /dev/null
+++ b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/README.md
@@ -0,0 +1,65 @@
+# PDF Summarizer on Spark
+
+Here we show how you can run the same Hamilton dataflow that we defined in the backend
+folder on Spark. This is useful if you want to run the same dataflow over a larger dataset,
+or if you have to run it on a cluster. Importantly, this means you don't have to rewrite your
+code or change where/how you develop!
+
+![Summarization dataflow](spark_summarization.dot.png)
+
+# File organization
+ - `summarization.py` is a carbon copy of the module in the backend folder.
+ - `run.py` contains the code to create a Spark job and run the summarization dataflow.
+
+# How this works
+We take the dataflow defined in `summarization.py` and execute it as a set of
+row-based UDFs on Spark. The magic that makes this work is the Hamilton `PySparkUDFGraphAdapter`.
+
+The premise is that there is a central dataframe that contains a column mapping to the
+dataflow's required input, which in this example is `pdf_source`. You can then request any
+intermediate outputs of the dataflow as appended columns, which in this example we do with
+`["raw_text", "chunked_text", "summarized_text"]`.
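+
+In code, the wiring is roughly the following (a condensed sketch of `run.py`; see that file for
+the full, runnable version):
+
+```python
+from hamilton import driver
+from hamilton.experimental import h_spark
+
+import summarization
+
+adapter = h_spark.PySparkUDFGraphAdapter()
+dr = driver.Driver(dict(file_type="pdf"), summarization, adapter=adapter)
+
+# `df` is a Spark dataframe with a `pdf_source` column -- every column becomes a candidate UDF input.
+execute_inputs = {col: df for col in df.columns}
+# scalar inputs are passed alongside the dataframe columns.
+execute_inputs.update(
+    dict(
+        openai_gpt_model="gpt-3.5-turbo-0613",
+        content_type="Scientific article",
+        user_query="Can you ELI5 the paper?",
+    )
+)
+# request the intermediate outputs you want; Hamilton appends them to `df` as columns.
+df = dr.execute(["raw_text", "chunked_text", "summarized_text"], inputs=execute_inputs)
+```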
+
+## Running the code
+
+1. Make sure you have the right dependencies installed. You can do this by running
+`pip install -r requirements.txt`.
+2. Download some PDFs, and then update `run.py` with the paths to them.
+3. Then you can run the code with `python run.py`. Be sure to have `OPENAI_API_KEY` set in your
+environment.
+
+# Sharing `summarization.py` in real life
+In this example we simply copied `summarization.py` to share the code. In real life
+you would most likely do one of the following:
+
+1. Create a Python package with your dataflows and publish/install it that way.
+2. Or, in lieu of publishing a package, share the code via the repository and augment the Python
+path, or zip the code up, so it can be shared between the FastAPI service and Spark (see the
+sketch below).
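+
+As a purely hypothetical illustration of option 2 (the paths below are placeholders, not part of
+this example), you could make the shared module importable on both the driver and the executors:
+
+```python
+import sys
+
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.getOrCreate()
+# make the shared code importable on the Spark driver...
+sys.path.append("/path/to/shared/dataflows")
+# ...and ship it to the executors so the UDFs can import it there too.
+spark.sparkContext.addPyFile("/path/to/shared/dataflows/summarization.py")
+
+import summarization  # noqa: E402  -- now resolvable on both driver and executors
+```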
+
+# Errors you might encounter:
+## Fork error on Mac
+If you are running Spark on a Mac, you might get the following error:
+
+```
+objc[95025]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
+objc[95025]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
+```
+Export the following environment variable to fix it before running the code:
+
+```bash
+export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
+```
+
+## PySpark error "got multiple values for argument"
+You should not get this error, but you might if you adjust the code.
+
+E.g.
+```python
+TypeError: prompt_and_text_content() got multiple values for argument 'summarize_text_from_summaries_prompt'
+```
+Solution -- ensure that the arguments that end up being dataframe columns are the leftmost
+arguments to each function, and come before any "primitive" (i.e. scalar) arguments. This is
+because primitive values are bound to the function via keyword arguments, while columns are
+passed in as positional arguments.
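+
+Concretely, this is why `prompt_and_text_content` in `summarization.py` declares its column input
+before its scalar inputs:
+
+```python
+# column inputs come first (passed positionally), scalar inputs after (bound via kwargs):
+def prompt_and_text_content(
+    summarized_chunks: str,                     # computed as a dataframe column
+    summarize_text_from_summaries_prompt: str,  # scalar input
+    user_query: str,                            # scalar input
+) -> str:
+    ...
+```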
diff --git a/examples/LLM_Workflows/pdf_summarizer/run_on_spark/requirements.txt b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/requirements.txt
new file mode 100644
index 000000000..7edb5f09b
--- /dev/null
+++ b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/requirements.txt
@@ -0,0 +1,8 @@
+openai
+PyPDF2
+pyspark
+sf-hamilton[visualization]
+tenacity
+tiktoken
+tqdm
+# dagworks-sdk>=0.0.14 # uncomment to use DAGWorks driver.
diff --git a/examples/LLM_Workflows/pdf_summarizer/run_on_spark/run.ipynb b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/run.ipynb
new file mode 100644
index 000000000..eb9bf0380
--- /dev/null
+++ b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/run.ipynb
@@ -0,0 +1,470 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "source": [
+ "# Notebook showing how to run PDF summarizer on Spark\n",
+ "In this notebook we'll walk through what's in `run.py`, which shows how one\n",
+    "can set up a Spark job to run the PDF summarizer dataflow defined in `summarization.py`.\n",
+ "\n",
+    "Note: if you're on a Mac you might need to set the following in your environment as you start Jupyter/this kernel:\n",
+ "> OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES\n",
+ "\n",
+    "For your OPENAI_API_KEY you can put it in the environment as well, or modify this notebook directly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:29:26.741436Z",
+ "start_time": "2023-08-19T20:29:24.782064Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# imports\n",
+ "import os\n",
+ "import pandas as pd\n",
+ "import summarization\n",
+ "from pyspark.sql import SparkSession\n",
+ "\n",
+ "from hamilton import driver, log_setup\n",
+ "from hamilton.experimental import h_spark"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:29:43.315228Z",
+ "start_time": "2023-08-19T20:29:38.212840Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# more setup for spark, etc.\n",
+ "openai_api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
+ "log_setup.setup_logging(log_level=log_setup.LOG_LEVELS[\"INFO\"])\n",
+ "# create the SparkSession -- note in real life, you'd adjust the number of executors to control parallelism.\n",
+ "spark = SparkSession.builder.config(\n",
+ " \"spark.executorEnv.OPENAI_API_KEY\", openai_api_key\n",
+    "#).config( # you might need the following in case things don't work for you.\n",
+ "# \"spark.sql.warehouse.dir\", \"~/temp/dwh\"\n",
+ "#).master(\n",
+ "# \"local[1]\" # Change this in real life.\n",
+ ").getOrCreate()\n",
+ "spark.sparkContext.setLogLevel(\"info\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:29:43.320253Z",
+ "start_time": "2023-08-19T20:29:43.317533Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Set up specifics for this example\n",
+ "openai_gpt_model = \"gpt-3.5-turbo-0613\"\n",
+ "content_type = \"Scientific article\"\n",
+ "user_query = \"Can you ELI5 the paper?\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "start_time": "2023-08-19T20:30:25.359581Z"
+ },
+ "collapsed": false,
+ "is_executing": true,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Create the input dataframe\n",
+ "# replace this with SQL or however you'd get the data you need in.\n",
+ "pandas_df = pd.DataFrame(\n",
+ " # TODO: update this to point to a PDF or two.\n",
+ " {\"pdf_source\": [\"a/path/to/a/PDF/CDMS2022-hamilton-paper.pdf\"]}\n",
+ ")\n",
+ "df = spark.createDataFrame(pandas_df)\n",
+ "df.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:29:54.797884Z",
+ "start_time": "2023-08-19T20:29:54.783823Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[WARNING] 2023-08-19 14:05:57,410 hamilton.telemetry(127): Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/dagworks-inc/hamilton#usage-analytics--data-privacy for details.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create the driver\n",
+ "modules = [summarization]\n",
+ "driver_config = dict(file_type=\"pdf\")\n",
+ "# create the Hamilton driver\n",
+ "adapter = h_spark.PySparkUDFGraphAdapter()\n",
+ "dr = driver.Driver(driver_config, *modules, adapter=adapter) # can pass in multiple modules\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:31:02.455114Z",
+ "start_time": "2023-08-19T20:31:02.435225Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# create inputs to the UDFs - this needs to be column_name -> spark dataframe.\n",
+ "execute_inputs = {col: df for col in df.columns}\n",
+ "# add in any other scalar inputs/values/objects needed by the UDFs\n",
+ "execute_inputs.update(\n",
+ " dict(\n",
+ " openai_gpt_model=openai_gpt_model,\n",
+ " content_type=content_type,\n",
+ " user_query=user_query,\n",
+ " )\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:31:04.846274Z",
+ "start_time": "2023-08-19T20:31:04.841799Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# tell Hamilton what columns need to be appended to the dataframe.\n",
+ "cols_to_append = [\n",
+ " \"raw_text\",\n",
+ " \"chunked_text\",\n",
+ " \"summarized_text\",\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:31:07.500434Z",
+ "start_time": "2023-08-19T20:31:07.478062Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# visualize execution of what is going to be appended\n",
+ "dr.visualize_execution(\n",
+ " cols_to_append, None, None, inputs=execute_inputs\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-19T20:25:26.570950Z",
+ "start_time": "2023-08-19T20:25:26.563166Z"
+ },
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# tell Hamilton to tell Spark what to do\n",
+ "df = dr.execute(cols_to_append, inputs=execute_inputs)\n",
+ "df.explain()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "df.show()\n",
+ "# you can also save the dataframe as a json file, parquet, etc.\n",
+ "# df.write.json(\"processed_pdfs\")\n",
+ "spark.stop()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/LLM_Workflows/pdf_summarizer/run_on_spark/run.py b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/run.py
new file mode 100644
index 000000000..4b4fe1339
--- /dev/null
+++ b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/run.py
@@ -0,0 +1,73 @@
+"""Spark driver and Hamilton driver code."""
+
+import pandas as pd
+import summarization
+from pyspark.sql import SparkSession
+
+from hamilton import driver, log_setup
+from hamilton.experimental import h_spark
+
+
+def my_spark_job(spark: SparkSession, openai_gpt_model: str, content_type: str, user_query: str):
+    """Template for a Spark job that uses Hamilton for its feature engineering, i.e. any map operations.
+
+ :param spark: the SparkSession
+ :param openai_gpt_model: the model to use for summarization
+ :param content_type: the content type of the document to summarize
+ :param user_query: the user query to use for summarization
+ """
+ # replace this with SQL or however you'd get the data you need in.
+ pandas_df = pd.DataFrame(
+ # TODO: update this to point to a PDF or two.
+ {"pdf_source": ["a/path/to/a/PDF/CDMS2022-hamilton-paper.pdf"]}
+ )
+ df = spark.createDataFrame(pandas_df)
+ # get the modules that contain the UDFs
+ modules = [summarization]
+ driver_config = dict(file_type="pdf")
+ # create the Hamilton driver
+ adapter = h_spark.PySparkUDFGraphAdapter()
+ dr = driver.Driver(driver_config, *modules, adapter=adapter) # can pass in multiple modules
+ # create inputs to the UDFs - this needs to be column_name -> spark dataframe.
+ execute_inputs = {col: df for col in df.columns}
+ # add in any other scalar inputs/values/objects needed by the UDFs
+ execute_inputs.update(
+ dict(
+ openai_gpt_model=openai_gpt_model,
+ content_type=content_type,
+ user_query=user_query,
+ )
+ )
+ # tell Hamilton what columns need to be appended to the dataframe.
+ cols_to_append = [
+ "raw_text",
+ "chunked_text",
+ "summarized_text",
+ ]
+ # visualize execution of what is going to be appended
+ dr.visualize_execution(
+ cols_to_append, "./spark_summarization.dot", {"format": "png"}, inputs=execute_inputs
+ )
+ # tell Hamilton to tell Spark what to do
+ df = dr.execute(cols_to_append, inputs=execute_inputs)
+ df.explain()
+ return df
+
+
+if __name__ == "__main__":
+ import os
+
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
+ log_setup.setup_logging(log_level=log_setup.LOG_LEVELS["INFO"])
+ # create the SparkSession -- note in real life, you'd adjust the number of executors to control parallelism.
+ spark = SparkSession.builder.config(
+ "spark.executorEnv.OPENAI_API_KEY", openai_api_key
+ ).getOrCreate()
+ spark.sparkContext.setLogLevel("info")
+ # run the job
+ _df = my_spark_job(spark, "gpt-3.5-turbo-0613", "Scientific article", "Can you ELI5 the paper?")
+ # show the dataframe & thus make spark compute something
+ _df.show()
+ # you can also save the dataframe as a json file, parquet, etc.
+ # _df.write.json("processed_pdfs")
+ spark.stop()
diff --git a/examples/LLM_Workflows/pdf_summarizer/run_on_spark/summarization.py b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/summarization.py
new file mode 100644
index 000000000..c56a0c94f
--- /dev/null
+++ b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/summarization.py
@@ -0,0 +1,182 @@
+import concurrent
+import tempfile
+from typing import Generator, Union
+
+import openai
+import tiktoken
+from PyPDF2 import PdfReader
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+from tqdm import tqdm
+
+from hamilton.function_modifiers import config
+
+"""
+This module is a carbon copy of the module in the backend. In real life you'd
+set up some package or structure that would allow you to share code between the
+two. However this is just an example, and rather than set up a whole package, or play
+with sys.path, we thought this would be simpler.
+"""
+
+
+def summarize_chunk_of_text_prompt(content_type: str = "an academic paper") -> str:
+ """Base prompt for summarizing chunks of text."""
+ return f"Summarize this text from {content_type}. Extract any key points with reasoning.\n\nContent:"
+
+
+def summarize_text_from_summaries_prompt(content_type: str = "an academic paper") -> str:
+ """Prompt for summarizing a paper from a list of summaries."""
+ return f"""Write a summary collated from this collection of key points extracted from {content_type}.
+ The summary should highlight the core argument, conclusions and evidence, and answer the user's query.
+ User query: {{query}}
+ The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.
+ Key points:\n{{results}}\nSummary:\n"""
+
+
+@config.when(file_type="pdf")
+def raw_text(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str:
+ """Takes a filepath to a PDF and returns a string of the PDF's contents
+ :param pdf_source: the path, or the temporary file, to the PDF.
+ :return: the text of the PDF.
+ """
+ reader = PdfReader(pdf_source)
+ _pdf_text = ""
+ page_number = 0
+ for page in reader.pages:
+ page_number += 1
+ _pdf_text += page.extract_text() + f"\nPage Number: {page_number}"
+ return _pdf_text
+
+
+def _create_chunks(text: str, n: int, tokenizer: tiktoken.Encoding) -> Generator[str, None, None]:
+ """Helper function. Returns successive n-sized chunks from provided text.
+ Split a text into smaller chunks of size n, preferably ending at the end of a sentence
+ :param text:
+ :param n:
+ :param tokenizer:
+ :return:
+ """
+ tokens = tokenizer.encode(text)
+ i = 0
+ while i < len(tokens):
+ # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens
+ j = min(i + int(1.5 * n), len(tokens))
+ while j > i + int(0.5 * n):
+ # Decode the tokens and check for full stop or newline
+ chunk = tokenizer.decode(tokens[i:j])
+ if chunk.endswith(".") or chunk.endswith("\n"):
+ break
+ j -= 1
+ # If no end of sentence found, use n tokens as the chunk size
+ if j == i + int(0.5 * n):
+ j = min(i + n, len(tokens))
+ yield tokens[i:j]
+ i = j
+
+
+def chunked_text(
+ raw_text: str, tokenizer_encoding: str = "cl100k_base", max_token_length: int = 1500
+) -> list[str]:
+ """Chunks the pdf text into smaller chunks of size max_token_length.
+ :param raw_text: the Series of individual pdf texts to chunk.
+ :param max_token_length: the maximum length of tokens in each chunk.
+ :param tokenizer_encoding: the encoding to use for the tokenizer.
+ :return: Series of chunked pdf text. Each element is a list of chunks.
+ """
+ tokenizer = tiktoken.get_encoding(tokenizer_encoding)
+ _encoded_chunks = _create_chunks(raw_text, max_token_length, tokenizer)
+ _decoded_chunks = [tokenizer.decode(chunk) for chunk in _encoded_chunks]
+ return _decoded_chunks
+
+
+@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
+def _summarize_chunk(content: str, template_prompt: str, openai_gpt_model: str) -> str:
+ """This function applies a prompt to some input content. In this case it returns a summarized chunk of text.
+ :param content: the content to summarize.
+ :param template_prompt: the prompt template to use to put the content into.
+ :param openai_gpt_model: the openai gpt model to use.
+ :return: the response from the openai API.
+ """
+ # NEED export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
+ prompt = template_prompt + content
+ response = openai.ChatCompletion.create(
+ model=openai_gpt_model, messages=[{"role": "user", "content": prompt}], temperature=0
+ )
+ return response["choices"][0]["message"]["content"]
+
+
+def summarized_chunks(
+ chunked_text: list[str], summarize_chunk_of_text_prompt: str, openai_gpt_model: str
+) -> str:
+ """Summarizes a series of chunks of text.
+ Note: this takes the first result from the top_n_related_articles series and summarizes it. This is because
+ the top_n_related_articles series is sorted by relatedness, so the first result is the most related.
+ :param chunked_text: a list of chunks of text for an article.
+ :param summarize_chunk_of_text_prompt: the prompt to use to summarize each chunk of text.
+ :param openai_gpt_model: the openai gpt model to use.
+ :return: a single string of each chunk of text summarized, concatenated together.
+ """
+ _summarized_text = ""
+ with concurrent.futures.ThreadPoolExecutor(max_workers=len(chunked_text)) as executor:
+ futures = [
+ executor.submit(
+ _summarize_chunk, chunk, summarize_chunk_of_text_prompt, openai_gpt_model
+ )
+ for chunk in chunked_text
+ ]
+ with tqdm(total=len(chunked_text)) as pbar:
+ for _ in concurrent.futures.as_completed(futures):
+ pbar.update(1)
+ for future in futures:
+ data = future.result()
+ _summarized_text += data
+ return _summarized_text
+
+
+def prompt_and_text_content(
+ summarized_chunks: str,
+ summarize_text_from_summaries_prompt: str,
+ user_query: str,
+) -> str:
+ """Creates the prompt for summarizing the text from the summarized chunks of the pdf.
+ :param summarized_chunks: a long string of chunked summaries of a file.
+ :param summarize_text_from_summaries_prompt: the template to use to summarize the chunks.
+ :param user_query: the original user query.
+ :return: the prompt to use to summarize the chunks.
+ """
+ return summarize_text_from_summaries_prompt.format(query=user_query, results=summarized_chunks)
+
+
+def summarized_text(
+ prompt_and_text_content: str,
+ openai_gpt_model: str,
+) -> str:
+ """Summarizes the text from the summarized chunks of the pdf.
+ :param prompt_and_text_content: the prompt and content to send over.
+ :param openai_gpt_model: which openai gpt model to use.
+ :return: the string response from the openai API.
+ """
+ response = openai.ChatCompletion.create(
+ model=openai_gpt_model,
+ messages=[
+ {
+ "role": "user",
+ "content": prompt_and_text_content,
+ }
+ ],
+ temperature=0,
+ )
+ return response["choices"][0]["message"]["content"]
+
+
+if __name__ == "__main__":
+ # run as a script to test Hamilton's execution
+ import summarization
+
+ from hamilton import base, driver
+
+ dr = driver.Driver(
+ {},
+ summarization,
+ adapter=base.SimplePythonGraphAdapter(base.DictResult()),
+ )
+ dr.display_all_functions("summary", {"format": "png"})
diff --git a/graph_adapter_tests/h_spark/test_h_spark.py b/graph_adapter_tests/h_spark/test_h_spark.py
index c3c5db195..feecf4c1a 100644
--- a/graph_adapter_tests/h_spark/test_h_spark.py
+++ b/graph_adapter_tests/h_spark/test_h_spark.py
@@ -1,8 +1,11 @@
+import sys
+
+import numpy as np
import pandas as pd
import pyspark.pandas as ps
import pytest
from pyspark import Row
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, types
from pyspark.sql.functions import column
from hamilton import base, driver, htypes, node
@@ -235,3 +238,115 @@ def test_smoke_screen_udf_graph_adatper(spark_session):
Row(a=2, b=5, base_func=7, base_func2=11, base_func3=9),
Row(a=3, b=6, base_func=9, base_func2=13, base_func3=9),
]
+
+
+# Test cases for python_to_spark_type function
+@pytest.mark.parametrize(
+ "python_type,expected_spark_type",
+ [
+ (int, types.IntegerType()),
+ (float, types.FloatType()),
+ (bool, types.BooleanType()),
+ (str, types.StringType()),
+ (bytes, types.BinaryType()),
+ ],
+)
+def test_python_to_spark_type_valid(python_type, expected_spark_type):
+ assert h_spark.python_to_spark_type(python_type) == expected_spark_type
+
+
+@pytest.mark.parametrize("invalid_python_type", [list, dict, tuple, set])
+def test_python_to_spark_type_invalid(invalid_python_type):
+ with pytest.raises(ValueError, match=f"Unsupported Python type: {invalid_python_type}"):
+ h_spark.python_to_spark_type(invalid_python_type)
+
+
+# Test cases for get_spark_type function
+# 1. Basic Python types
+@pytest.mark.parametrize(
+ "return_type,expected_spark_type",
+ [
+ (int, types.IntegerType()),
+ (float, types.FloatType()),
+ (bool, types.BooleanType()),
+ (str, types.StringType()),
+ (bytes, types.BinaryType()),
+ ],
+)
+def test_get_spark_type_basic_types(
+ dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type
+):
+ assert (
+ h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type)
+ == expected_spark_type
+ )
+
+
+# 2. Lists of basic Python types
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher")
+@pytest.mark.parametrize(
+ "return_type,expected_spark_type",
+ [
+ (int, types.ArrayType(types.IntegerType())),
+ (float, types.ArrayType(types.FloatType())),
+ (bool, types.ArrayType(types.BooleanType())),
+ (str, types.ArrayType(types.StringType())),
+ (bytes, types.ArrayType(types.BinaryType())),
+ ],
+)
+def test_get_spark_type_list_types(
+ dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type
+):
+ return_type = list[return_type] # type: ignore
+ assert (
+ h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type)
+ == expected_spark_type
+ )
+
+
+# 3. Numpy types (assuming you have a numpy_to_spark_type function that handles these)
+@pytest.mark.parametrize(
+ "return_type,expected_spark_type",
+ [
+ (np.int64, types.IntegerType()),
+ (np.float64, types.FloatType()),
+ (np.bool_, types.BooleanType()),
+ ],
+)
+def test_get_spark_type_numpy_types(
+ dummy_kwargs, dummy_df, dummy_udf, return_type, expected_spark_type
+):
+ assert (
+ h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, return_type)
+ == expected_spark_type
+ )
+
+
+# 4. Unsupported types
+@pytest.mark.parametrize(
+ "unsupported_return_type", [dict, set, tuple] # Add other unsupported types as needed
+)
+def test_get_spark_type_unsupported(dummy_kwargs, dummy_df, dummy_udf, unsupported_return_type):
+ with pytest.raises(
+ ValueError, match=f"Currently unsupported return type {unsupported_return_type}."
+ ):
+ h_spark.get_spark_type(dummy_kwargs, dummy_df, dummy_udf, unsupported_return_type)
+
+
+# Dummy values for the tests
+@pytest.fixture
+def dummy_kwargs():
+ return {}
+
+
+@pytest.fixture
+def dummy_df():
+ return spark.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
+
+
+@pytest.fixture
+def dummy_udf():
+ def dummyfunc(x: int) -> int:
+ return x
+
+ return dummyfunc
diff --git a/hamilton/experimental/h_spark.py b/hamilton/experimental/h_spark.py
index 7f1de7d5e..0883fb692 100644
--- a/hamilton/experimental/h_spark.py
+++ b/hamilton/experimental/h_spark.py
@@ -1,7 +1,8 @@
import functools
import inspect
import logging
-from typing import Any, Callable, Dict, Set, Tuple, Type, Union
+import sys
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
import numpy as np
import pandas as pd
@@ -10,6 +11,7 @@
from pyspark.sql.functions import column, lit, pandas_udf, udf
from hamilton import base, htypes, node
+from hamilton.node import DependencyType
logger = logging.getLogger(__name__)
@@ -177,7 +179,7 @@ def numpy_to_spark_type(numpy_type: Type) -> types.DataType:
raise ValueError("Unsupported NumPy type: " + str(numpy_type))
-def python_to_spark_type(python_type: Union[int, float, bool, str, bytes]) -> types.DataType:
+def python_to_spark_type(python_type: Type[Union[int, float, bool, str, bytes]]) -> types.DataType:
"""Function to convert a Python type to a Spark type.
:param python_type: the Python type to convert.
@@ -198,11 +200,19 @@ def python_to_spark_type(python_type: Union[int, float, bool, str, bytes]) -> ty
raise ValueError("Unsupported Python type: " + str(python_type))
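+# Note: on Python < 3.9 the builtin `list` isn't subscriptable (list[int] raises a TypeError),
+# so fall back to the typing.List variants for the supported element types.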
+if sys.version_info < (3, 9):
+ _list = (List[int], List[float], List[bool], List[str], List[bytes])
+else:
+ _list = (list[int], list[float], list[bool], list[str], list[bytes])
+
+
def get_spark_type(
actual_kwargs: dict, df: DataFrame, hamilton_udf: Callable, return_type: Any
) -> types.DataType:
if return_type in (int, float, bool, str, bytes):
return python_to_spark_type(return_type)
+ elif return_type in _list:
+ return types.ArrayType(python_to_spark_type(return_type.__args__[0]))
elif hasattr(return_type, "__module__") and getattr(return_type, "__module__") == "numpy":
return numpy_to_spark_type(return_type)
else:
@@ -259,6 +269,8 @@ def _bind_parameters_to_callable(
hamilton_udf = functools.partial(
hamilton_udf, **{input_name: actual_kwargs[input_name]}
)
+ elif node_input_types[input_name][1] == DependencyType.OPTIONAL:
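+                # optional input not provided as a column or kwarg; leave it unbound so the
+                # function's default value is used.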
+ pass
else:
raise ValueError(
f"Cannot satisfy {node_name} with input types {node_input_types} against a dataframe with "