diff --git a/docs/conf.py b/docs/conf.py
index ea274db7d..35f790006 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -30,6 +30,7 @@
     "sphinx.ext.autosummary",
     "myst_parser",
     "sphinx_sitemap",
+    "docs.data_adapters_extension",
 ]
 # for the sitemap extension ---
diff --git a/docs/data_adapters_extension.py b/docs/data_adapters_extension.py
new file mode 100644
index 000000000..923c78cfb
--- /dev/null
+++ b/docs/data_adapters_extension.py
@@ -0,0 +1,288 @@
+import dataclasses
+import inspect
+import os
+from typing import List, Optional, Tuple, Type
+
+import git
+from docutils import nodes
+from docutils.parsers.rst import Directive
+
+import hamilton.io.data_adapters
+from hamilton import registry
+
+"""A module to crawl available data adapters and generate documentation for them.
+Note these currently link out to the source code on GitHub, but they should
+be linking to the documentation instead, which hasn't been generated yet.
+"""
+
+# These have fallbacks for local dev
+GIT_URL = os.environ.get("READTHEDOCS_GIT_CLONE_URL", "https://github.com/dagworks-inc/hamilton")
+GIT_ID = os.environ.get("READTHEDOCS_GIT_IDENTIFIER", "main")
+
+# All the modules that register data adapters
+# When you register a new one, add it here
+MODULES_TO_IMPORT = ["hamilton.io.default_data_loaders", "hamilton.plugins.pandas_extensions"]
+
+for module in MODULES_TO_IMPORT:
+    __import__(module)
+
+
+def get_git_root(path: str) -> str:
+    """Returns the git root of a repo, given an absolute path to
+    a file within the repo.
+
+    :param path: Path to a file within a git repo
+    :return: The root of the git repo
+    """
+    git_repo = git.Repo(path, search_parent_directories=True)
+    git_root = git_repo.git.rev_parse("--show-toplevel")
+    return git_root
+
+
+@dataclasses.dataclass
+class Param:
+    name: str
+    type: str
+    default: Optional[str] = None
+
+
+def get_default(param: dataclasses.Field) -> Optional[str]:
+    """Gets the default of a dataclass field, if it has one.
+
+    :param param: The dataclass field
+    :return: The str representation of the default.
+    """
+    if param.default is dataclasses.MISSING:
+        return None
+    return str(param.default)
+
+
+def get_lines_for_class(class_: Type[Type]) -> Tuple[int, int]:
+    """Gets the set of lines in which a class is implemented.
+
+    :param class_: The class to get the lines for
+    :return: A tuple of the start and end lines
+    """
+    lines = inspect.getsourcelines(class_)
+    start_line = lines[1]
+    end_line = lines[1] + len(lines[0])
+    return start_line, end_line
+
+
+def get_class_repr(class_: Type) -> str:
+    """Gets a representation of a class that can be used in documentation.
+
+    :param class_: Python class to get the representation for
+    :return: Str representation
+    """
+
+    try:
+        return class_.__qualname__
+    except AttributeError:
+        # This happens when we have generics or other oddities
+        return str(class_)
+
+
+@dataclasses.dataclass
+class AdapterInfo:
+    key: str
+    class_name: str
+    class_path: str
+    load_params: List[Param]
+    save_params: List[Param]
+    applicable_types: List[str]
+    file_: str
+    line_nos: Tuple[int, int]
+
+    @staticmethod
+    def from_loader(loader: Type[hamilton.io.data_adapters.DataLoader]) -> "AdapterInfo":
+        """Utility constructor to create the AdapterInfo from a DataLoader class
+
+        :param loader: DataLoader class
+        :return: AdapterInfo derived from it
+        """
+
+        return AdapterInfo(
+            key=loader.name(),
+            class_name=loader.__name__,
+            class_path=loader.__module__,
+            load_params=[
+                Param(name=p.name, type=get_class_repr(p.type), default=get_default(p))
+                for p in dataclasses.fields(loader)
+            ]
+            if issubclass(loader, hamilton.io.data_adapters.DataLoader)
+            else None,
+            save_params=[
+                Param(name=p.name, type=get_class_repr(p.type), default=get_default(p))
+                for p in dataclasses.fields(loader)
+            ]
+            if issubclass(loader, hamilton.io.data_adapters.DataSaver)
+            else None,
+            applicable_types=[get_class_repr(t) for t in loader.applicable_types()],
+            file_=inspect.getfile(loader),
+            line_nos=get_lines_for_class(loader),
+        )
+
+
+def _collect_loaders(saver_or_loader: str) -> List[Type[hamilton.io.data_adapters.AdapterCommon]]:
+    """Collects all loaders (or savers) from the registry.
+
+    :return: A deduplicated list of adapter classes
+    """
+    out = []
+    loaders = (
+        list(registry.LOADER_REGISTRY.values())
+        if saver_or_loader == "loader"
+        else list(registry.SAVER_REGISTRY.values())
+    )
+    for classes in loaders:
+        for cls in classes:
+            if cls not in out:
+                out.append(cls)
+    return out
+
+
+# Utility functions to render different components of the adapter in table cells
+
+
+def render_key(key: str):
+    return [nodes.Text(key, key)]
+
+
+def render_class_name(class_name: str):
+    return [nodes.literal(text=class_name)]
+
+
+def render_class_path(class_path: str, file_: str, line_start: int, line_end: int):
+    git_path = get_git_root(file_)
+    file_relative_to_git_root = os.path.relpath(file_, git_path)
+    href = f"{GIT_URL}/blob/{GIT_ID}/{file_relative_to_git_root}#L{line_start}-L{line_end}"
+    # href = f"{GIT_URL}/blob/{GIT_ID}/{file_}#L{line_no}"
+    return [nodes.raw("", f'<a href="{href}">{class_path}</a>', format="html")]
+
+
+def render_adapter_params(load_params: Optional[List[Param]]):
+    if load_params is None:
+        return nodes.raw("", "<br/>", format="html")
", format="html") + fieldlist = nodes.field_list() + for i, load_param in enumerate(load_params): + fieldname = nodes.Text(load_param.name) + fieldbody = nodes.literal( + text=load_param.type + + ("=" + load_param.default if load_param.default is not None else "") + ) + field = nodes.field("", fieldname, fieldbody) + fieldlist += field + if i < len(load_params) - 1: + fieldlist += nodes.raw("", "
", format="html") + return fieldlist + + +def render_applicable_types(applicable_types: List[str]): + fieldlist = nodes.field_list() + for applicable_type in applicable_types: + fieldlist += nodes.field("", nodes.literal(text=applicable_type), nodes.Text("")) + fieldlist += nodes.raw("", "
", format="html") + return fieldlist + + +class DataAdapterTableDirective(Directive): + """Custom directive to render a table of all data adapters. Takes in one argument + that is either 'loader' or 'saver' to indicate which adapters to render.""" + + has_content = True + required_arguments = 1 # Number of required arguments + + def run(self): + """Runs the directive. This does the following: + 1. Collects all loaders from the registry + 2. Creates a table with the following columns: + - Key + - Class name + - Class path + - Load params + - Applicable types + 3. Returns the table + :return: A list of nodes that Sphinx will render, consisting of the table node + """ + saver_or_loader = self.arguments[0] + if saver_or_loader not in ("loader", "saver"): + raise ValueError( + f"loader_or_saver must be one of 'loader' or 'saver', " f"got {saver_or_loader}" + ) + table_data = [ + AdapterInfo.from_loader(loader) for loader in _collect_loaders(saver_or_loader) + ] + + # Create the table and add columns + table_node = nodes.table() + tgroup = nodes.tgroup(cols=6) + table_node += tgroup + + # Create columns + key_spec = nodes.colspec(colwidth=1) + # class_spec = nodes.colspec(colwidth=1) + load_params_spec = nodes.colspec(colwidth=2) + applicable_types_spec = nodes.colspec(colwidth=1) + class_path_spec = nodes.colspec(colwidth=1) + + tgroup += [key_spec, load_params_spec, applicable_types_spec, class_path_spec] + + # Create the table body + thead = nodes.thead() + row = nodes.row() + + # Create entry nodes for each cell + key_entry = nodes.entry() + load_params_entry = nodes.entry() + applicable_types_entry = nodes.entry() + class_path_entry = nodes.entry() + + key_entry += nodes.paragraph(text="key") + + load_params_entry += nodes.paragraph(text=f"{saver_or_loader} params") + applicable_types_entry += nodes.paragraph(text="types") + class_path_entry += nodes.paragraph(text="module") + + row += [key_entry, load_params_entry, applicable_types_entry, class_path_entry] + thead += row + tgroup += thead + tbody = nodes.tbody() + tgroup += tbody + + # Populate table rows based on your table_data + for row_data in table_data: + row = nodes.row() + + # Create entry nodes for each cell + key_entry = nodes.entry() + load_params_entry = nodes.entry() + applicable_types_entry = nodes.entry() + class_path_entry = nodes.entry() + + # Create a paragraph node for each entry + # import pdb + # pdb.set_trace() + # para1 = nodes.literal(text=row_data['column1_data']) + # para2 = nodes.paragraph(text=row_data['column2_data']) + + # Add the paragraph nodes to the entry nodes + key_entry += render_key(row_data.key) + load_params_entry += render_adapter_params(row_data.load_params) + applicable_types_entry += render_applicable_types(row_data.applicable_types) + class_path_entry += render_class_path( + row_data.class_path, row_data.file_, *row_data.line_nos + ) + + # Add the entry nodes to the row + row += [key_entry, load_params_entry, applicable_types_entry, class_path_entry] + + # Add the row to the table body + tbody += row + + return [table_node] + + +def setup(app): + """Required to register the extension""" + app.add_directive("data_adapter_table", DataAdapterTableDirective) diff --git a/docs/index.md b/docs/index.md index b82c522e8..ee17efafe 100644 --- a/docs/index.md +++ b/docs/index.md @@ -41,6 +41,7 @@ contributing reference/decorators/index reference/drivers/index +reference/io/index reference/graph-adapters/index reference/result-builders/index reference/miscellaneous/index diff --git 
a/docs/reference/io/adapter-documentation.rst b/docs/reference/io/adapter-documentation.rst new file mode 100644 index 000000000..6394a4c02 --- /dev/null +++ b/docs/reference/io/adapter-documentation.rst @@ -0,0 +1,20 @@ +========================= +Data Adapters +========================= + +Reference for data adapter base classes: + +.. autoclass:: hamilton.io.data_adapters.DataLoader + :special-members: __init__ + :members: + :inherited-members: + +.. autoclass:: hamilton.io.data_adapters.DataSaver + :special-members: __init__ + :members: + :inherited-members: + +.. autoclass:: hamilton.io.data_adapters.AdapterCommon + :special-members: __init__ + :members: + :inherited-members: diff --git a/docs/reference/io/available-data-adapters.rst b/docs/reference/io/available-data-adapters.rst new file mode 100644 index 000000000..a8c5a2469 --- /dev/null +++ b/docs/reference/io/available-data-adapters.rst @@ -0,0 +1,56 @@ +======================== +Using Data Adapters +======================== + +This is an index of all the available data adapters, both savers and loaders. +Note that some savers and loaders are the same (certain classes can handle both), +but some are different. You will want to reference this when calling out to any of the following: + +1. Using :doc:`/reference/decorators/save_to/`. +2. Using :doc:`/reference/decorators/load_from/`. +3. Using :doc:`materialize ` + +To read these tables, you want to first look at the key to determine which format you want -- +these should be human-readable and familiar to you. Then you'll want to look at the `types` field +to figure out which is the best for your case (the object you want to load from or save to). + +Finally, look up the adapter params to see what parameters you can pass to the data adapters. +The optional params come with their default value specified. + +If you want more information, click on the `module`, it will send you to the code that implements +it to see how the parameters are used. + +As an example, say we wanted to save a pandas dataframe to a CSV file. We would first find the +key `csv`, which would inform us that we want to call `save_to.csv` (or `to.csv` in the case +of `materialize`). Then, we would look at the `types` field, finding that there is a pandas +dataframe adapter. Finally, we would look at the `params` field, finding that we can pass +`path`, and (optionally) `sep` (which we'd realize defaults to `,` when looking at the code). + +All together, we'd end up with: + +.. code-block:: python + + import pandas as pd + from hamilton.function_modifiers import value, save_to + + @save_to.csv(path=value("my_file.csv")) + def my_data(...) -> pd.DataFrame: + ... + +And we're good to go! + +If you want to extend these, see :doc:`/reference/io/available-data-adapters` for documentation, +and `the example `_ +in the repository for an example of how to do so. + +============= +Data Loaders +============= + +.. data_adapter_table:: loader + +============= +Data Savers +============= + +.. data_adapter_table:: saver diff --git a/docs/reference/io/index.rst b/docs/reference/io/index.rst new file mode 100644 index 000000000..637a69c03 --- /dev/null +++ b/docs/reference/io/index.rst @@ -0,0 +1,11 @@ +============== +I/O +============== + +This section contains any information about I/O within Hamilton + +.. 
toctree::
+   :maxdepth: 2
+
+   available-data-adapters
+   adapter-documentation
diff --git a/examples/materialization/README.md b/examples/materialization/README.md
new file mode 100644
index 000000000..3a6c73e8d
--- /dev/null
+++ b/examples/materialization/README.md
@@ -0,0 +1,66 @@
+# Materialization
+
+Hamilton's driver allows for ad-hoc materialization. This enables you to take a DAG you already have,
+and save your data to a set of custom locations/URLs.
+
+Note that these materializers are _isomorphic_ in nature to the
+[@save_to](https://hamilton.dagworks.io/en/latest/reference/decorators/save_to/)
+decorator. Materializers inject the additional node at runtime, modifying the
+DAG to include a data saver node, and returning the metadata around materialization.
+
+This framework is meant to be highly pluggable. While the set of available data savers is currently
+limited, we expect folks to build their own materializers (and, hopefully, contribute them back to the community!).
+
+
+## Example
+In this example we take the scikit-learn iris_loader pipeline, and materialize outputs to specific
+locations through a driver call. We demonstrate:
+
+1. Saving model parameters to a json file (using the default json materializer)
+2. Writing custom data adapters for:
+   1. Pickling a model to an object file
+   2. Saving confusion matrices to a csv file
+
+See [run.py](run.py) for the full example.
+
+In this example we only pass literal values to the materializers. That said, you can use both `source` (to specify the source from an upstream node),
+and `value` (which is the default) to specify literals.
+
+
+## `driver.materialize`
+
+This will be a high-level overview. For more details,
+see the [documentation](https://hamilton.dagworks.io/en/latest/reference/drivers/Driver/#hamilton.driver.Driver.materialize).
+
+`driver.materialize()` does the following:
+1. Processes a list of materializers to create a new DAG
+2. Alters the output to include the materializer nodes
+3. Processes a list of "additional variables" (for debugging) to return intermediary data
+4. Executes the DAG, including the materializers
+5. Returns a tuple of (`materialization metadata`, `additional variables`)
+
+Materializers each consume:
+1. A `dependencies` list to materialize
+2. An (optional) `combine` parameter to combine the outputs of the dependencies
+(this is required if there are multiple dependencies). This is a [ResultMixin](https://hamilton.dagworks.io/en/latest/concepts/customizing-execution/#result-builders) object
+3. An `id` parameter to identify the materializer, which serves as the node name in the DAG
+
+Materializers are referenced by the `to` object in `hamilton.io.materialization`, which utilizes
+dynamic dispatch to create the appropriate materializer.
+
+These refer to `DataSaver` classes, which are keyed by a string (e.g. `csv`).
+Multiple data adapters can share the same key, each of which applies to a specific type
+(e.g. pandas dataframe, numpy matrix, polars dataframe). New
+data adapters are registered by calling `hamilton.registry.register_adapter`.
+
+## Custom Materializers
+
+To define a custom materializer, all you have to do is implement the `DataSaver` class
+(which will allow use in `save_to` as well). This is demonstrated in [custom_materializers.py](custom_materializers.py),
+and sketched out just below.
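+
+For orientation, here is roughly what a custom saver looks like: a dataclass that implements
+`save_data`, `applicable_types`, and `name`, and is then registered with `hamilton.registry`.
+(The `StringToTextFile`/`txt` naming below is purely illustrative; the adapters this example
+actually uses live in [custom_materializers.py](custom_materializers.py).)
+
+```python
+import dataclasses
+from typing import Any, Collection, Dict, Type
+
+from hamilton import registry
+from hamilton.io import utils
+from hamilton.io.data_adapters import DataSaver
+
+
+@dataclasses.dataclass
+class StringToTextFile(DataSaver):
+    # Constructor fields become the parameters you pass to the materializer
+    path: str
+
+    def save_data(self, data: str) -> Dict[str, Any]:
+        # Write the data, then return metadata describing what was materialized
+        with open(self.path, "w") as f:
+            f.write(data)
+        return utils.get_file_metadata(self.path)
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        # The python types this saver knows how to handle
+        return [str]
+
+    @classmethod
+    def name(cls) -> str:
+        # The key this saver is referenced by
+        return "txt"
+
+
+# Registering makes it available to @save_to and driver.materialize
+registry.register_adapter(StringToTextFile)
+```
+
+Once registered, the key returned by `name()` would become addressable as `to.txt(...)` in
+`driver.materialize` (or `save_to.txt` as a decorator), just like the built-in adapters.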
+
+## `driver.materialize` vs `@save_to`
+
+`driver.materialize` is an ad-hoc form of `save_to`. You want to use this when you're developing, and
+want to do ad-hoc materialization. When you have a production ETL, you can choose between `save_to` and `materialize`.
+If the save location/structure is unlikely to change, then you might consider using `save_to`. Otherwise, `materialize`
+is an idiomatic way of conducting the materialization operations that cleanly separates side effects from transformations.
diff --git a/examples/materialization/custom_materializers.py b/examples/materialization/custom_materializers.py
new file mode 100644
index 000000000..a2508ae43
--- /dev/null
+++ b/examples/materialization/custom_materializers.py
@@ -0,0 +1,55 @@
+import dataclasses
+import pickle
+from typing import Any, Collection, Dict, Type
+
+import numpy as np
+from sklearn import base
+
+from hamilton import registry
+from hamilton.io import utils
+from hamilton.io.data_adapters import DataSaver
+
+# TODO -- put this back in the standard library
+
+
+@dataclasses.dataclass
+class NumpyMatrixToCSV(DataSaver):
+    path: str
+    sep: str = ","
+
+    def __post_init__(self):
+        if not self.path.endswith(".csv"):
+            raise ValueError(f"CSV files must end with .csv, got {self.path}")
+
+    def save_data(self, data: np.ndarray) -> Dict[str, Any]:
+        np.savetxt(self.path, data, delimiter=self.sep)
+        return utils.get_file_metadata(self.path)
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        return [np.ndarray]
+
+    @classmethod
+    def name(cls) -> str:
+        return "csv"
+
+
+@dataclasses.dataclass
+class SKLearnPickler(DataSaver):
+    path: str
+
+    def save_data(self, data: base.ClassifierMixin) -> Dict[str, Any]:
+        pickle.dump(data, open(self.path, "wb"))
+        return utils.get_file_metadata(self.path)
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        return [base.ClassifierMixin]
+
+    @classmethod
+    def name(cls) -> str:
+        return "pickle"
+
+
+for adapter in [NumpyMatrixToCSV, SKLearnPickler]:
+    registry.register_adapter(adapter)
diff --git a/examples/materialization/dag.pdf b/examples/materialization/dag.pdf
new file mode 100644
index 000000000..3d123bb07
Binary files /dev/null and b/examples/materialization/dag.pdf differ
diff --git a/examples/materialization/data_loaders.py b/examples/materialization/data_loaders.py
new file mode 100644
index 000000000..e31061eab
--- /dev/null
+++ b/examples/materialization/data_loaders.py
@@ -0,0 +1,30 @@
+import numpy as np
+from sklearn import datasets, utils
+
+from hamilton.function_modifiers import config
+
+"""
+Module to load digit data.
+""" + + +@config.when(data_loader="iris") +def data__iris() -> utils.Bunch: + return datasets.load_digits() + + +@config.when(data_loader="digits") +def data__digits() -> utils.Bunch: + return datasets.load_digits() + + +def target(data: utils.Bunch) -> np.ndarray: + return data.target + + +def target_names(data: utils.Bunch) -> np.ndarray: + return data.target_names + + +def feature_matrix(data: utils.Bunch) -> np.ndarray: + return data.data diff --git a/examples/materialization/model_training.py b/examples/materialization/model_training.py new file mode 100644 index 000000000..894cc533a --- /dev/null +++ b/examples/materialization/model_training.py @@ -0,0 +1,96 @@ +from typing import Dict + +import numpy as np +from sklearn import base, linear_model, metrics, svm +from sklearn.model_selection import train_test_split + +from hamilton import function_modifiers + + +@function_modifiers.config.when(clf="svm") +def prefit_clf__svm(gamma: float = 0.001) -> base.ClassifierMixin: + """Returns an unfitted SVM classifier object. + + :param gamma: ... + :return: + """ + return svm.SVC(gamma=gamma) + + +@function_modifiers.config.when(clf="logistic") +def prefit_clf__logreg(penalty: str) -> base.ClassifierMixin: + """Returns an unfitted Logistic Regression classifier object. + + :param penalty: + :return: + """ + return linear_model.LogisticRegression(penalty) + + +@function_modifiers.extract_fields( + {"X_train": np.ndarray, "X_test": np.ndarray, "y_train": np.ndarray, "y_test": np.ndarray} +) +def train_test_split_func( + feature_matrix: np.ndarray, + target: np.ndarray, + test_size_fraction: float, + shuffle_train_test_split: bool, +) -> Dict[str, np.ndarray]: + """Function that creates the training & test splits. + + It this then extracted out into constituent components and used downstream. 
+ + :param feature_matrix: + :param target: + :param test_size_fraction: + :param shuffle_train_test_split: + :return: + """ + X_train, X_test, y_train, y_test = train_test_split( + feature_matrix, target, test_size=test_size_fraction, shuffle=shuffle_train_test_split + ) + return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test} + + +def y_test_with_labels(y_test: np.ndarray, target_names: np.ndarray) -> np.ndarray: + """Adds labels to the target output.""" + return np.array([target_names[idx] for idx in y_test]) + + +def fit_clf( + prefit_clf: base.ClassifierMixin, X_train: np.ndarray, y_train: np.ndarray +) -> base.ClassifierMixin: + """Calls fit on the classifier object; it mutates it.""" + prefit_clf.fit(X_train, y_train) + return prefit_clf + + +def predicted_output(fit_clf: base.ClassifierMixin, X_test: np.ndarray) -> np.ndarray: + """Exercised the fit classifier to perform a prediction.""" + return fit_clf.predict(X_test) + + +def predicted_output_with_labels( + predicted_output: np.ndarray, target_names: np.ndarray +) -> np.ndarray: + """Replaces the predictions with the desired labels.""" + return np.array([target_names[idx] for idx in predicted_output]) + + +def classification_report( + predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray +) -> str: + """Returns a classification report.""" + return metrics.classification_report(y_test_with_labels, predicted_output_with_labels) + + +def confusion_matrix( + predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray +) -> str: + """Returns a confusion matrix report.""" + return metrics.confusion_matrix(y_test_with_labels, predicted_output_with_labels) + + +def model_parameters(fit_clf: base.ClassifierMixin) -> dict: + """Returns a dictionary of model parameters.""" + return fit_clf.get_params() diff --git a/examples/materialization/notebook.ipynb b/examples/materialization/notebook.ipynb new file mode 100644 index 000000000..0a414930a --- /dev/null +++ b/examples/materialization/notebook.ipynb @@ -0,0 +1,535 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "id": "7bf6a40d", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "\n", + "import data_loaders\n", + "import model_training\n", + "\n", + "from hamilton import base, driver\n", + "from hamilton.io.materialization import to\n", + "import pandas as pd\n", + "\n", + "import custom_materializers" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7a449245", + "metadata": {}, + "outputs": [], + "source": [ + "dag_config = {\n", + " \"test_size_fraction\": 0.5,\n", + " \"shuffle_train_test_split\": True,\n", + " \"data_loader\" : \"iris\",\n", + " \"clf\" : \"logistic\",\n", + " \"penalty\" : \"l2\"\n", + "}\n", + "dr = (\n", + " driver.Builder()\n", + " .with_adapter(base.DefaultAdapter())\n", + " .with_config(dag_config)\n", + " .with_modules(data_loaders, model_training)\n", + " .build()\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "397b09bc", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "predicted_output_with_labels\n", + "\n", + "predicted_output_with_labels\n", + "\n", + "\n", + "\n", + "predicted_output_with_labels_to_csv\n", + "\n", + "predicted_output_with_labels_to_csv\n", + "\n", + "\n", + "\n", + "predicted_output_with_labels->predicted_output_with_labels_to_csv\n", + "\n", + "\n", + "\n", + "\n", + 
"\n", + "classification_report\n", + "\n", + "classification_report\n", + "\n", + "\n", + "\n", + "predicted_output_with_labels->classification_report\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "train_test_split_func\n", + "\n", + "train_test_split_func\n", + "\n", + "\n", + "\n", + "y_test\n", + "\n", + "y_test\n", + "\n", + "\n", + "\n", + "train_test_split_func->y_test\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "X_train\n", + "\n", + "X_train\n", + "\n", + "\n", + "\n", + "train_test_split_func->X_train\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "X_test\n", + "\n", + "X_test\n", + "\n", + "\n", + "\n", + "train_test_split_func->X_test\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "y_train\n", + "\n", + "y_train\n", + "\n", + "\n", + "\n", + "train_test_split_func->y_train\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fit_clf\n", + "\n", + "fit_clf\n", + "\n", + "\n", + "\n", + "predicted_output\n", + "\n", + "predicted_output\n", + "\n", + "\n", + "\n", + "fit_clf->predicted_output\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "clf_to_pickle\n", + "\n", + "clf_to_pickle\n", + "\n", + "\n", + "\n", + "fit_clf->clf_to_pickle\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "model_parameters\n", + "\n", + "model_parameters\n", + "\n", + "\n", + "\n", + "fit_clf->model_parameters\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "data\n", + "\n", + "data\n", + "\n", + "\n", + "\n", + "target_names\n", + "\n", + "target_names\n", + "\n", + "\n", + "\n", + "data->target_names\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "feature_matrix\n", + "\n", + "feature_matrix\n", + "\n", + "\n", + "\n", + "data->feature_matrix\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "target\n", + "\n", + "target\n", + "\n", + "\n", + "\n", + "data->target\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "penalty\n", + "\n", + "Input: penalty\n", + "\n", + "\n", + "\n", + "prefit_clf\n", + "\n", + "prefit_clf\n", + "\n", + "\n", + "\n", + "penalty->prefit_clf\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "predicted_output->predicted_output_with_labels\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "shuffle_train_test_split\n", + "\n", + "Input: shuffle_train_test_split\n", + "\n", + "\n", + "\n", + "shuffle_train_test_split->train_test_split_func\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "target_names->predicted_output_with_labels\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "y_test_with_labels\n", + "\n", + "y_test_with_labels\n", + "\n", + "\n", + "\n", + "target_names->y_test_with_labels\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "y_test->y_test_with_labels\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "X_train->fit_clf\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "model_params_to_json\n", + "\n", + "model_params_to_json\n", + "\n", + "\n", + "\n", + "feature_matrix->train_test_split_func\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "classification_report_to_txt\n", + "\n", + "classification_report_to_txt\n", + "\n", + "\n", + "\n", + "classification_report->classification_report_to_txt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "prefit_clf->fit_clf\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "X_test->predicted_output\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "y_test_with_labels->classification_report\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "test_size_fraction\n", + "\n", + "Input: test_size_fraction\n", + "\n", + "\n", + "\n", + "test_size_fraction->train_test_split_func\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "model_parameters->model_params_to_json\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "y_train->fit_clf\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "target->train_test_split_func\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers = [\n", + " to.json(\n", + " dependencies=[\"model_parameters\"],\n", + " id=\"model_params_to_json\",\n", + " path=\"./data/params.json\"\n", + " ),\n", + " # classification report to .txt file\n", + " to.file(\n", + " dependencies=[\"classification_report\"],\n", + " id=\"classification_report_to_txt\",\n", + " path=\"./data/classification_report.txt\",\n", + " ),\n", + " # materialize the model to a pickle file\n", + " to.pickle(\n", + " dependencies=[\"fit_clf\"], id=\"clf_to_pickle\", path=\"./data/clf.pkl\"\n", + " ),\n", + " # materialize the predictions we made to a csv file\n", + " to.csv(\n", + " dependencies=[\"predicted_output_with_labels\"],\n", + " id=\"predicted_output_with_labels_to_csv\",\n", + " path=\"./data/predicted_output_with_labels.csv\",\n", + " ),\n", + " ]\n", + "\n", + "dr.visualize_materialization(\n", + " *materializers,\n", + " additional_vars=[\"classification_report\"],\n", + " output_file_path=None,\n", + " render_kwargs={},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f5727b54", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/elijahbenizzy/.pyenv/versions/3.9.10/envs/hamilton/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + } + ], + "source": [ + "materialization_results, additional_vars = dr.materialize(\n", + " # materialize model parameters to json\n", + " *materializers,\n", + " additional_vars=[\"classification_report\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8bdfde70", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 1.00 1.00 1.00 94\n", + " 1 0.91 0.93 0.92 85\n", + " 2 0.97 0.99 0.98 96\n", + " 3 0.99 0.97 0.98 93\n", + " 4 0.99 0.92 0.95 88\n", + " 5 0.95 0.95 0.95 85\n", + " 6 0.99 0.97 0.98 97\n", + " 7 0.97 0.97 0.97 89\n", + " 8 0.88 0.88 0.88 82\n", + " 9 0.91 0.97 0.94 90\n", + "\n", + " accuracy 0.96 899\n", + " macro avg 0.95 0.95 0.95 899\n", + "weighted avg 0.96 0.96 0.96 899\n", + "\n" + ] + } + ], + "source": [ + "print(additional_vars['classification_report'])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a6f5fe83", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 1.00 1.00 1.00 94\n", + " 1 0.91 0.93 0.92 85\n", + " 2 0.97 0.99 0.98 96\n", + " 3 0.99 0.97 0.98 93\n", + " 4 0.99 0.92 0.95 88\n", + " 5 0.95 0.95 0.95 85\n", + " 6 0.99 0.97 0.98 97\n", + " 7 0.97 0.97 0.97 89\n", + " 8 0.88 0.88 0.88 82\n", + " 9 0.91 0.97 0.94 90\n", + "\n", + " accuracy 0.96 899\n", + " macro avg 
0.95 0.95 0.95 899\n", + "weighted avg 0.96 0.96 0.96 899\n", + "\n" + ] + } + ], + "source": [ + "print(open((materialization_results['classification_report_to_txt']['path'])).read())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/materialization/requirements.txt b/examples/materialization/requirements.txt new file mode 100644 index 000000000..3f69ad5c2 --- /dev/null +++ b/examples/materialization/requirements.txt @@ -0,0 +1,2 @@ +scikit-learn +sf-hamilton diff --git a/examples/materialization/run.py b/examples/materialization/run.py new file mode 100644 index 000000000..fce1faee7 --- /dev/null +++ b/examples/materialization/run.py @@ -0,0 +1,87 @@ +""" +Example script showing how one might setup a generic model training pipeline that is quickly configurable. +""" + +import importlib + +# Required import to register adapters +import os + +import data_loaders +import model_training + +from hamilton import base, driver +from hamilton.io.materialization import to + +# This has to be imported, but the linter doesn't like it cause its unused +# We just need to import it to register the materializers +importlib.import_module("custom_materializers") + + +def get_model_config(model_type: str) -> dict: + """Returns model type specific configuration""" + if model_type == "svm": + return {"clf": "svm", "gamma": 0.001} + elif model_type == "logistic": + return {"clf": "logistic", "penalty": "l2"} + else: + raise ValueError(f"Unsupported model {model_type}.") + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 3: + print("Error: required arguments are [iris|digits] [svm|logistic]") + sys.exit(1) + _data_set = sys.argv[1] # the data set to load + _model_type = sys.argv[2] # the model type to fit and evaluate with + + dag_config = { + "test_size_fraction": 0.5, + "shuffle_train_test_split": True, + } + if not os.path.exists("data"): + os.mkdir("data") + # augment config + dag_config.update(get_model_config(_model_type)) + dag_config["data_loader"] = _data_set + dr = ( + driver.Builder() + .with_adapter(base.DefaultAdapter()) + .with_config(dag_config) + .with_modules(data_loaders, model_training) + .build() + ) + materializers = [ + to.json( + dependencies=["model_parameters"], id="model_params_to_json", path="./data/params.json" + ), + # classification report to .txt file + to.file( + dependencies=["classification_report"], + id="classification_report_to_txt", + path="./data/classification_report.txt", + ), + # materialize the model to a pickle file + to.pickle(dependencies=["fit_clf"], id="clf_to_pickle", path="./data/clf.pkl"), + # materialize the predictions we made to a csv file + to.csv( + dependencies=["predicted_output_with_labels"], + id="predicted_output_with_labels_to_csv", + path="./data/predicted_output_with_labels.csv", + ), + ] + dr.visualize_materialization( + *materializers, + additional_vars=["classification_report"], + output_file_path="./dag", + render_kwargs={}, + ) + materialization_results, additional_vars = dr.materialize( + # materialize model parameters to json + *materializers, + additional_vars=["classification_report"], + ) + # 
print(materialization_results["classification_report"]) + # print(additional_vars) diff --git a/hamilton/driver.py b/hamilton/driver.py index 28e738b88..dab5f7d08 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -997,9 +997,9 @@ def materialize( def visualize_materialization( self, *materializers: materialization.MaterializerFactory, - additional_vars: List[Union[str, Callable, Variable]], output_file_path: str, render_kwargs: dict, + additional_vars: List[Union[str, Callable, Variable]] = None, inputs: Dict[str, Any] = None, graphviz_kwargs: dict = None, ) -> Optional["graphviz.Digraph"]: # noqa F821 @@ -1014,11 +1014,13 @@ def visualize_materialization( :param graphviz_kwargs: Arguments to pass to graphviz :return: The graphviz graph, if you want to do something with it """ + if additional_vars is None: + additional_vars = [] function_graph = materialization.modify_graph(self.graph, materializers) _final_vars = self._create_final_vars(additional_vars) + [ materializer.id for materializer in materializers ] - Driver._visualize_execution_helper( + return Driver._visualize_execution_helper( function_graph, self.adapter, _final_vars, @@ -1177,9 +1179,7 @@ def build(self) -> Driver: execution_manager = self.execution_manager if execution_manager is None: local_executor = self.local_executor or executors.SynchronousLocalTaskExecutor() - remote_executor = self.remote_executor or executors.MultiProcessingExecutor( - max_tasks=10 - ) + remote_executor = self.remote_executor or executors.MultiThreadingExecutor(max_tasks=10) execution_manager = executors.DefaultExecutionManager( local_executor=local_executor, remote_executor=remote_executor ) diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py index a5deb9ae3..333561516 100644 --- a/hamilton/function_modifiers/adapters.py +++ b/hamilton/function_modifiers/adapters.py @@ -1,6 +1,6 @@ import inspect import typing -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Collection, Dict, List, Tuple, Type from hamilton import node from hamilton.function_modifiers.base import ( @@ -139,7 +139,7 @@ def _select_param_to_inject(self, params: List[str], fn: Callable) -> str: def inject_nodes( self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable - ) -> Optional[Collection[node.Node]]: + ) -> Tuple[Collection[node.Node], Dict[str, str]]: pass """Generates two nodes: 1. A node that loads the data from the data source, and returns that + metadata @@ -217,7 +217,7 @@ def get_input_type_key(key: str) -> str: "hamilton.data_loader.classname": f"{loader_cls.__qualname__}", "hamilton.data_loader.node": inject_parameter, }, - namespace=("load_data", fn.__name__), + namespace=(fn.__name__, "load_data"), ) # the filter node is the node that takes the data from the data source, filters out @@ -239,8 +239,9 @@ def filter_function(_inject_parameter=inject_parameter, **kwargs): "hamilton.data_loader.classname": f"{loader_cls.__qualname__}", "hamilton.data_loader.node": inject_parameter, }, + namespace=(fn.__name__, "select_data"), ) - return [loader_node, filter_node] + return [loader_node, filter_node], {inject_parameter: filter_node.name} def _get_inject_parameter_from_function(self, fn: Callable) -> Tuple[str, Type[Type]]: """Gets the name of the parameter to inject the data into. @@ -311,7 +312,9 @@ def __getattr__(cls, item: str): f"Available loaders are: {LOADER_REGISTRY.keys()}. 
" f"If you've gotten to this point, you either (1) spelled the " f"loader name wrong, (2) are trying to use a loader that does" - f"not exist (yet)" + f"not exist (yet). For a list of available loaders, see: " + f"https://hamilton.readthedocs.io/reference/io/available-data-adapters/#data" + f"-loaders " ) from e @@ -424,11 +427,13 @@ def __getattr__(cls, item: str): return super().__getattribute__(item) except AttributeError as e: raise AttributeError( - f"No saver named: {item} available for {cls.__name__}. " - f"Available data savers are: {list(SAVER_REGISTRY.keys())}. " - f"If you've gotten to this point, you either (1) spelled the " - f"loader name wrong, (2) are trying to use a saver that does" - f"not exist (yet)." + "No saver named: {item} available for {cls.__name__}. " + "Available data savers are: {list(SAVER_REGISTRY.keys())}. " + "If you've gotten to this point, you either (1) spelled the " + "loader name wrong, (2) are trying to use a saver that does" + "not exist (yet). For a list of available savers, see " + "https://hamilton.readthedocs.io/reference/io/available-data-adapters/#data" + "-loaders " ) from e diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index 808999365..debbe37f9 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -10,7 +10,7 @@ except ImportError: # python3.10 and above EllipsisType = type(...) -from typing import Any, Callable, Collection, Dict, List, Optional, Type, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union from hamilton import node, registry, settings @@ -228,6 +228,26 @@ def transform_dag( pass +# TODO -- delete this/replace with the version that will be added by +# https://github.com/DAGWorks-Inc/hamilton/pull/249/ as part of the Node class +def _reassign_input_names(node_: node.Node, input_names: Dict[str, Any]) -> node.Node: + """Reassigns the input names of a node. Useful for applying + a node to a separate input if needed. Note that things can get a + little strange if you have multiple inputs with the same name, so + be careful about how you use this. + :param input_names: Input name map to reassign + :return: A node with the input names reassigned + """ + + def new_callable(**kwargs) -> Any: + reverse_input_names = {v: k for k, v in input_names.items()} + return node_.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()}) + + new_input_types = {input_names.get(k, k): v for k, v in node_.input_types.items()} + out = node_.copy_with(callabl=new_callable, input_types=new_input_types) + return out + + class NodeInjector(SubDAGModifier, abc.ABC): """Injects a value as a source node in the DAG. 
This is a special case of the SubDAGModifier, which gets all the upstream (required) nodes from the subdag and gives the decorator a chance @@ -275,21 +295,37 @@ def transform_dag( :return: """ injectable_params = NodeInjector.find_injectable_params(nodes) - out = list(nodes) - out.extend(self.inject_nodes(injectable_params, config, fn)) + nodes_to_inject, rename_map = self.inject_nodes(injectable_params, config, fn) + out = [] + for node_ in nodes: + # if there's an intersection then we want to rename the input + if set(node_.input_types.keys()) & set(rename_map.keys()): + out.append(_reassign_input_names(node_, rename_map)) + else: + out.append(node_) + out.extend(nodes_to_inject) + if len(set([node_.name for node_ in out])) != len(out): + import pdb + + pdb.set_trace() + print([node_.name for node_ in out]) return out @abc.abstractmethod def inject_nodes( self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable - ) -> List[node.Node]: + ) -> Tuple[List[node.Node], Dict[str, str]]: """Adds a set of nodes to inject into the DAG. These get injected into the specified param name, - meaning that exactly one of the output nodes will have that name. + meaning that exactly one of the output nodes will have that name. Note that this also allows + input renaming, meaning that the injector can rename the input to something else (to avoid + name-clashes). :param params: Dictionary of all the type names one wants to inject :param config: Configuration with which the DAG was constructed. :param fn: original function we're decorating. This is useful largely for debugging. - :return: A list of nodes to add. Empty if you wish to inject nothing + :return: A list of nodes to add. Empty if you wish to inject nothing, as well as a dictionary, + allowing the injector to rename the inputs (e.g. if you want the name to be + namespaced to avoid clashes) """ pass diff --git a/hamilton/function_modifiers/dependencies.py b/hamilton/function_modifiers/dependencies.py index 1d3dcbd7f..575ed7188 100644 --- a/hamilton/function_modifiers/dependencies.py +++ b/hamilton/function_modifiers/dependencies.py @@ -29,7 +29,7 @@ class SingleDependency(ParametrizedDependency, abc.ABC): @dataclasses.dataclass -class LiteralDependency(ParametrizedDependency): +class LiteralDependency(SingleDependency): value: Any def get_dependency_type(self) -> ParametrizedDependencySource: @@ -37,7 +37,7 @@ def get_dependency_type(self) -> ParametrizedDependencySource: @dataclasses.dataclass -class UpstreamDependency(ParametrizedDependency): +class UpstreamDependency(SingleDependency): source: str def get_dependency_type(self) -> ParametrizedDependencySource: diff --git a/hamilton/graph.py b/hamilton/graph.py index 2f82fd645..57ec99ca1 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -68,19 +68,23 @@ def add_dependency( def update_dependencies( - nodes: Dict[str, node.Node], adapter: base.HamiltonGraphAdapter, in_place: bool = True + nodes: Dict[str, node.Node], adapter: base.HamiltonGraphAdapter, reset_dependencies: bool = True ): - """Adds dependecies to a dictionary of nodes. If in_place is False, + """Adds dependencies to a dictionary of nodes. If in_place is False, it will deepcopy the dict + nodes and return that. Otherwise it will mutate + return the passed-in dict + nodes. :param in_place: Whether or not to modify in-place, or copy/return :param nodes: Nodes that form the DAG we're updating :param adapter: Adapter to use for type checking + :param reset_dependencies: Whether or not to reset the dependencies. 
If they are not set this is + unnecessary, and we can save yet another pass. Note that `reset` will perform an in-place + operation. :return: The updated nodes """ - if not in_place: - nodes = {k: v for k, v in nodes.items()} + # copy without the dependencies to avoid duplicates + if reset_dependencies: + nodes = {k: v.copy(include_refs=False) for k, v in nodes.items()} for node_name, n in list(nodes.items()): for param_name, (param_type, _) in n.input_types.items(): add_dependency(n, node_name, nodes, param_name, param_type, adapter) @@ -118,7 +122,8 @@ def create_function_graph( ) nodes[n.name] = n # add dependencies -- now that all nodes exist, we just run through edges & validate graph. - update_dependencies(nodes, adapter) # in place + nodes = update_dependencies(nodes, adapter, reset_dependencies=False) # no dependencies + # present yet for key in config.keys(): if key not in nodes: nodes[key] = node.Node(key, Any, node_source=node.NodeType.EXTERNAL) diff --git a/hamilton/io/data_adapters.py b/hamilton/io/data_adapters.py index 965a1f20f..b997227e6 100644 --- a/hamilton/io/data_adapters.py +++ b/hamilton/io/data_adapters.py @@ -153,12 +153,12 @@ class DataSaver(AdapterCommon, abc.ABC): @abc.abstractmethod def save_data(self, data: Any) -> Dict[str, Any]: """Saves the data to the data source. - Note this uses the constructor parameters to determine - how to save the data. + Note this uses the constructor parameters to determine + how to save the data. :return: Any relevant metadata. This is up the the data saver, but will likely - include the URI, etc... This is going to be similar to the metadata returned - by the data loader in the loading tuple. + include the URI, etc... This is going to be similar to the metadata returned + by the data loader in the loading tuple. """ pass diff --git a/hamilton/io/materialization.py b/hamilton/io/materialization.py index 3688d7a6f..09952cb82 100644 --- a/hamilton/io/materialization.py +++ b/hamilton/io/materialization.py @@ -7,7 +7,7 @@ from hamilton.function_modifiers.dependencies import SingleDependency, value from hamilton.graph import FunctionGraph from hamilton.io.data_adapters import DataSaver -from hamilton.registry import LOADER_REGISTRY +from hamilton.registry import SAVER_REGISTRY class materialization_meta__(type): @@ -19,20 +19,22 @@ class in registry, or make it a function that just proxies to the decorator. We """ def __getattr__(cls, item: str): - if item in LOADER_REGISTRY: - potential_loaders = LOADER_REGISTRY[item] + if item in SAVER_REGISTRY: + potential_loaders = SAVER_REGISTRY[item] savers = [loader for loader in potential_loaders if issubclass(loader, DataSaver)] if len(savers) > 0: - return Materialize.partial(LOADER_REGISTRY[item]) + return Materialize.partial(SAVER_REGISTRY[item]) try: return super().__getattribute__(item) except AttributeError as e: raise AttributeError( - f"No loader named: {item} available for {cls.__name__}. " - f"Available loaders are: {LOADER_REGISTRY.keys()}. " - f"If you've gotten to this point, you either (1) spelled the " - f"loader name wrong, (2) are trying to use a loader that does" - f"not exist (yet)" + "No data materializer named: {item}. " + "Available materializers are: {SAVER_REGISTRY.keys()}. " + "If you've gotten to this point, you either (1) spelled the " + "loader name wrong, (2) are trying to use a loader that does" + "not exist (yet). 
For a list of available materializers, see " + "https://hamilton.readthedocs.io/reference/io/available-data-adapters/#data" + "-loaders " ) from e @@ -76,6 +78,7 @@ def _process_kwargs( """ processed_kwargs = {} for kwarg, kwarg_val in data_saver_kwargs.items(): + if not isinstance(kwarg_val, SingleDependency): processed_kwargs[kwarg] = value(kwarg_val) else: diff --git a/hamilton/node.py b/hamilton/node.py index 50f4a5c9b..3bba90aa6 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -108,7 +108,7 @@ def __init__( DependencyType.from_parameter(value), ) elif self.user_defined: - if input_types is not None: + if len(self._input_types) > 0: raise ValueError( f"Input types cannot be provided for user-defined node {self.name}" ) @@ -266,11 +266,12 @@ def from_fn(fn: Callable, name: str = None) -> "Node": node_source=node_source, ) - def copy_with(self, **overrides) -> "Node": + def copy_with(self, include_refs: bool = True, **overrides) -> "Node": """Copies a node with the specified overrides for the constructor arguments. Utility function for creating a node -- useful for modifying it. :param kwargs: kwargs to use in place of the node. Passed to the constructor. + :param include_refs: Whether or not to include dependencies and depended_on_by :return: A node copied from self with the specified keyword arguments replaced. """ constructor_args = dict( @@ -284,4 +285,20 @@ def copy_with(self, **overrides) -> "Node": originating_functions=self.originating_functions, ) constructor_args.update(**overrides) - return Node(**constructor_args) + out = Node(**constructor_args) + if include_refs: + out._dependencies = self._dependencies + out._depended_on_by = self._depended_on_by + return out + + def copy(self, include_refs: bool = True) -> "Node": + """Copies a node, not modifying anything (except for the references + /dependencies if specified). + + :param include_refs: Whether or not to include dependencies and depended_on_by + :return: A copy of the node. + """ + """Gives a copy of the node, so we can modify it without modifying the original. + :return: A copy of the node. 
+ """ + return self.copy_with(include_refs) diff --git a/requirements-docs.txt b/requirements-docs.txt index 36137690b..2403cff5d 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -3,6 +3,7 @@ alabaster>=0.7,<0.8,!=0.7.5 # read the docs pins commonmark==0.9.1 # read the docs pins dask[distributed] furo +gitpython # Required for parsing git info for generation of data-adapter docs mock==1.0.1 # read the docs pins myst-parser==0.18.1 # latest version of myst at this time pillow diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py index 49cbda1f6..cba3470f2 100644 --- a/tests/function_modifiers/test_adapters.py +++ b/tests/function_modifiers/test_adapters.py @@ -75,17 +75,17 @@ def fn(data: int) -> int: nodes_by_name = {node_.name: node_ for node_ in nodes} assert len(nodes_by_name) == 3 assert "fn" in nodes_by_name - assert nodes_by_name["data"].tags == { + assert nodes_by_name["fn.load_data.data"].tags == { "hamilton.data_loader.source": "mock", "hamilton.data_loader": True, - "hamilton.data_loader.has_metadata": False, + "hamilton.data_loader.has_metadata": True, "hamilton.data_loader.node": "data", "hamilton.data_loader.classname": MockDataLoader.__qualname__, } - assert nodes_by_name["load_data.fn.data"].tags == { + assert nodes_by_name["fn.select_data.data"].tags == { "hamilton.data_loader.source": "mock", "hamilton.data_loader": True, - "hamilton.data_loader.has_metadata": True, + "hamilton.data_loader.has_metadata": False, "hamilton.data_loader.node": "data", "hamilton.data_loader.classname": MockDataLoader.__qualname__, } @@ -333,7 +333,7 @@ def fn_str_inject(injected_data: str) -> str: ) result = fg.execute(inputs={}, nodes=fg.nodes.values()) assert result["fn_str_inject"] == "foo" - assert result["load_data.fn_str_inject.injected_data"] == ( + assert result["fn_str_inject.load_data.injected_data"] == ( "foo", {"loader": "string_data_loader"}, ) @@ -362,12 +362,12 @@ def fn_str_inject(injected_data_1: str, injected_data_2: int) -> str: ) result = fg.execute(inputs={}, nodes=fg.nodes.values()) assert result["fn_str_inject"] == "foofoo" - assert result["load_data.fn_str_inject.injected_data_1"] == ( + assert result["fn_str_inject.load_data.injected_data_1"] == ( "foo", {"loader": "string_data_loader"}, ) - assert result["load_data.fn_str_inject.injected_data_2"] == ( + assert result["fn_str_inject.load_data.injected_data_2"] == ( 2, {"loader": "int_data_loader_2"}, ) diff --git a/tests/test_graph.py b/tests/test_graph.py index 0431b0893..96d908506 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -220,15 +220,6 @@ def test_add_dependency_user_nodes(): assert func_node.depended_on_by == [] -def test_create_function_graph_simple(): - """Tests that we create a simple function graph.""" - expected = create_testing_nodes() - actual = graph.create_function_graph( - tests.resources.dummy_functions, config={}, adapter=base.SimplePythonDataFrameGraphAdapter() - ) - assert actual == expected - - def create_testing_nodes(): """Helper function for creating the nodes represented in dummy_functions.py.""" nodes = { @@ -275,6 +266,15 @@ def create_testing_nodes(): return nodes +def test_create_function_graph_simple(): + """Tests that we create a simple function graph.""" + expected = create_testing_nodes() + actual = graph.create_function_graph( + tests.resources.dummy_functions, config={}, adapter=base.SimplePythonDataFrameGraphAdapter() + ) + assert actual == expected + + def test_execute(): """Tests graph execution along 
with basic memoization since A is depended on by two functions.""" adapter = base.SimplePythonDataFrameGraphAdapter() @@ -800,3 +800,11 @@ def my_function(A: int, b: int, c: int) -> int: ) results = fg.execute([n for n in fg.get_nodes() if n.name in ["my_function", "A"]]) assert results == {"A": 4, "b": 3, "c": 1, "my_function": 8} + + +def test_update_dependencies(): + nodes = create_testing_nodes() + new_nodes = graph.update_dependencies(nodes, base.DefaultAdapter()) + for node_name, node_ in new_nodes.items(): + assert node_.dependencies == nodes[node_name].dependencies + assert node_.depended_on_by == nodes[node_name].depended_on_by