Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds more to parallel caveats #746

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 132 additions & 2 deletions docs/concepts/parallel-task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ Note that we currently have the following caveats:
3. `Collect[]` input types are limited to one per function -- this is another caveat that we intend to get rid of, but for now you'll want to concat/put into one function before collecting.

Known Caveats
^^^^^^^^^^^^^
=============
If you're familiar with multi-processing then these caveats will be familiar to you. If not, then you should be aware of the following:

**Serialization**:
Serialization
^^^^^^^^^^^^^

Challenge:

Expand All @@ -94,3 +95,132 @@ Solution:
* Another option is write a customer wrapper function that uses `__set_state__` and `__get_state__` to serialize and deserialize the object.
* See [this issue](https://github.com/DAGWorks-Inc/hamilton/issues/743) for details and possible features to make
this simpler to deal with.


Multiple Collects
^^^^^^^^^^^^^^^^^

Currently, by design (see all limitations `here <https://github.com/DAGWorks-Inc/hamilton/issues/301>`_), you can only have one "collect" downstream of "parallel".

So the following code WILL NOT WORK:

.. code-block:: python

import logging

from hamilton import driver
from hamilton.execution.executors import SynchronousLocalTaskExecutor
from hamilton.htypes import Collect, Parallelizable
import pandas as pd


ANALYSIS_OB = tuple[tuple[str,...], pd.DataFrame]
ANALYSIS_RES = dict[str, str | float]


def split_by_cols(full_data: pd.DataFrame, columns: list[str]) -> Parallelizable[ANALYSIS_OB]:
for idx, grp in full_data.groupby(columns):
yield (idx, grp)


def sub_metric_1(split_by_cols: ANALYSIS_OB, number: float=1.0) -> ANALYSIS_RES:
idx, grp = split_by_cols
return {"key": idx, "mean": grp["spend"].mean() + number}


def sub_metric_2(split_by_cols: ANALYSIS_OB) -> ANALYSIS_RES:
idx, grp = split_by_cols
return {"key": idx, "mean": grp["signups"].mean()}


def metric_1(sub_metric_1: Collect[ANALYSIS_RES], columns: list[str]) -> pd.DataFrame:
data = [[k for k in d["key"]] + [d["mean"], "spend"] for d in sub_metric_1]
cols = list(columns) + ["mean", "metric"]
return pd.DataFrame(data, columns=cols)


def metric_2(sub_metric_2: Collect[ANALYSIS_RES], columns: list[str]) -> pd.DataFrame:
data = [[k for k in d["key"]] + [d["mean"], "signups"] for d in sub_metric_2]
cols = list(columns) + ["mean", "metric"]
return pd.DataFrame(data, columns=cols)


# this will not work because you can't have two Collect[] calls downstream from a Parallelizable[] call
def all_agg(metric_1: pd.DataFrame, metric_2: pd.DataFrame) -> pd.DataFrame:
return pd.concat([metric_1, metric_2])


if __name__ == "__main__":
from hamilton.execution import executors
import __main__

from hamilton.log_setup import setup_logging
setup_logging(log_level=logging.DEBUG)

local_executor = executors.SynchronousLocalTaskExecutor()

dr = (
driver.Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(__main__)
.with_remote_executor(local_executor)
.build()
)
df = pd.DataFrame(
index=pd.date_range('20230101', '20230110'),
data={
"signups": [1, 10, 50, 100, 200, 400, 700, 800, 1000, 1300],
"spend": [10, 10, 20, 40, 40, 50, 100, 80, 90, 120],
"region": ["A", "B", "C", "A", "B", "C", "A", "B", "C", "X"],
}
)
ans = dr.execute(
["all_agg"],
inputs={
"full_data": df,
"number": 3.1,
"columns": ["region"],
}
)
print(ans["all_agg"])


To fix this, (this is documented in this `issue <https://github.com/DAGWorks-Inc/hamilton/issues/742>`_) you can either create a new function that combines the two `Collect[]` calls that could be combined with
:doc:`@config.when <../reference/decorators/config_when>`.

.. code-block:: python

def all_metrics(sub_metric_1: ANALYSIS_RES, sub_metric_2: ANALYSIS_RES) -> ANALYSIS_RES:
return ... # join the two dicts in whatever way you want

def all_agg(all_metrics: Collect[ANALYSIS_RES]) -> pd.DataFrame:
return ... # join them all into a dataframe

Or you use :doc:`@resolve <../reference/decorators/resolve>`,
with :doc:`@group (scroll down a little) <../reference/decorators/parameterize>`,
:doc:`@inject <../reference/decorators/inject>`,
to set what should be determined to be collected at DAG construction time:

.. code-block:: python

@resolve(
when=ResolveAt.CONFIG_AVAILABLE,
decorate_with= lambda metric_names:
inject( # this will annotate the function with @inject
# it will then inject a group of values corresponding to the sources wanted
sub_metrics=group(*[source(x) for x in metric_names])
),
)
def all_metrics(sub_metrics: list[ANALYSIS_RES], columns: list[str]) -> pd.DataFrame:
frames = []
for a in sub_metrics:
frames.append(_to_frame(a, columns))
return pd.concat(frames)

# then in your driver:
from hamilton import settings
_config = {settings.ENABLE_POWER_USER_MODE:True}
_config["metric_names"] = ["sub_metric_1", "sub_metric_2"]

# Then in the driver building pass in the configuration:
.with_config(_config)
2 changes: 1 addition & 1 deletion docs/reference/decorators/parameterize.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ documentation. If you don't, it will use the parameterized docstring.

**Reference Documentation**

Classes to help with @parameterize:
Classes to help with @parameterize (also can be used with :doc:`@inject <inject>`):

.. autoclass:: hamilton.function_modifiers.ParameterizedExtract

Expand Down