diff --git a/docs/conf.py b/docs/conf.py
index ea274db7d..35f790006 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -30,6 +30,7 @@
"sphinx.ext.autosummary",
"myst_parser",
"sphinx_sitemap",
+ "docs.data_adapters_extension",
]
# for the sitemap extension ---
diff --git a/docs/data_adapters_extension.py b/docs/data_adapters_extension.py
new file mode 100644
index 000000000..923c78cfb
--- /dev/null
+++ b/docs/data_adapters_extension.py
@@ -0,0 +1,288 @@
+import dataclasses
+import inspect
+import os
+from typing import List, Optional, Tuple, Type
+
+import git
+from docutils import nodes
+from docutils.parsers.rst import Directive
+
+import hamilton.io.data_adapters
+from hamilton import registry
+
+"""A module to crawl available data adapters and generate documentation for them.
+Note these currently link out to the source code on GitHub, but they should
+be linking to the documentation instead, which hasn't been generated yet.
+"""
+
+# These have fallbacks for local dev
+GIT_URL = os.environ.get("READTHEDOCS_GIT_CLONE_URL", "https://github.com/dagworks-inc/hamilton")
+GIT_ID = os.environ.get("READTHEDOCS_GIT_IDENTIFIER", "main")
+
+# All the modules that register data adapters
+# When you register a new one, add it here
+MODULES_TO_IMPORT = ["hamilton.io.default_data_loaders", "hamilton.plugins.pandas_extensions"]
+
+for module in MODULES_TO_IMPORT:
+ __import__(module)
+
+
+def get_git_root(path: str) -> str:
+    """Returns the git root of a repo, given an absolute path to
+ a file within the repo.
+
+ :param path: Path to a file within a git repo
+ :return: The root of the git repo
+ """
+ git_repo = git.Repo(path, search_parent_directories=True)
+ git_root = git_repo.git.rev_parse("--show-toplevel")
+ return git_root
+
+
+@dataclasses.dataclass
+class Param:
+ name: str
+ type: str
+ default: Optional[str] = None
+
+
+def get_default(param: dataclasses.Field) -> Optional[str]:
+    """Gets the default of a dataclass field, if it has one.
+
+ :param param: The dataclass field
+ :return: The str representation of the default.
+ """
+ if param.default is dataclasses.MISSING:
+ return None
+ return str(param.default)
+
+
+def get_lines_for_class(class_: Type[Type]) -> Tuple[int, int]:
+ """Gets the set of lines in which a class is implemented
+
+ :param class_: The class to get the lines for
+ :return: A tuple of the start and end lines
+ """
+ lines = inspect.getsourcelines(class_)
+ start_line = lines[1]
+ end_line = lines[1] + len(lines[0])
+ return start_line, end_line
+
+
+def get_class_repr(class_: Type) -> str:
+ """Gets a representation of a class that can be used in documentation.
+
+ :param class_: Python class to get the representation for
+ :return: Str representation
+ """
+
+ try:
+ return class_.__qualname__
+ except AttributeError:
+ # This happens when we have generics or other oddities
+ return str(class_)
+
+
+@dataclasses.dataclass
+class AdapterInfo:
+ key: str
+ class_name: str
+ class_path: str
+ load_params: List[Param]
+ save_params: List[Param]
+ applicable_types: List[str]
+ file_: str
+ line_nos: Tuple[int, int]
+
+ @staticmethod
+ def from_loader(loader: Type[hamilton.io.data_adapters.DataLoader]) -> "AdapterInfo":
+ """Utility constructor to create the AdapterInfo from a DataLoader class
+
+ :param loader: DataLoader class
+ :return: AdapterInfo derived from it
+ """
+
+ return AdapterInfo(
+ key=loader.name(),
+ class_name=loader.__name__,
+ class_path=loader.__module__,
+ load_params=[
+ Param(name=p.name, type=get_class_repr(p.type), default=get_default(p))
+ for p in dataclasses.fields(loader)
+ ]
+            if issubclass(loader, hamilton.io.data_adapters.DataLoader)
+ else None,
+ save_params=[
+ Param(name=p.name, type=get_class_repr(p.type), default=get_default(p))
+ for p in dataclasses.fields(loader)
+ ]
+ if issubclass(loader, hamilton.io.data_adapters.DataSaver)
+ else None,
+ applicable_types=[get_class_repr(t) for t in loader.applicable_types()],
+ file_=inspect.getfile(loader),
+ line_nos=get_lines_for_class(loader),
+ )
+
+
+def _collect_loaders(saver_or_loader: str) -> List[Type[hamilton.io.data_adapters.AdapterCommon]]:
+    """Collects all adapters of the requested kind from the registry.
+
+    :param saver_or_loader: Either "loader" or "saver", indicating which registry to collect from.
+    :return: A de-duplicated list of adapter classes.
+    """
+ out = []
+ loaders = (
+ list(registry.LOADER_REGISTRY.values())
+ if saver_or_loader == "loader"
+ else list(registry.SAVER_REGISTRY.values())
+ )
+ for classes in loaders:
+ for cls in classes:
+ if cls not in out:
+ out.append(cls)
+ return out
+
+
+# Utility functions to render different components of the adapter in table cells
+
+
+def render_key(key: str):
+ return [nodes.Text(key, key)]
+
+
+def render_class_name(class_name: str):
+ return [nodes.literal(text=class_name)]
+
+
+def render_class_path(class_path: str, file_: str, line_start: int, line_end: int):
+ git_path = get_git_root(file_)
+ file_relative_to_git_root = os.path.relpath(file_, git_path)
+ href = f"{GIT_URL}/blob/{GIT_ID}/{file_relative_to_git_root}#L{line_start}-L{line_end}"
+    return [nodes.raw("", f'<a href="{href}">{class_path}</a>', format="html")]
+
+
+def render_adapter_params(load_params: Optional[List[Param]]):
+ if load_params is None:
+        return nodes.raw("", "<br/>", format="html")
+ fieldlist = nodes.field_list()
+ for i, load_param in enumerate(load_params):
+ fieldname = nodes.Text(load_param.name)
+ fieldbody = nodes.literal(
+ text=load_param.type
+ + ("=" + load_param.default if load_param.default is not None else "")
+ )
+ field = nodes.field("", fieldname, fieldbody)
+ fieldlist += field
+ if i < len(load_params) - 1:
+            fieldlist += nodes.raw("", "<br/>", format="html")
+ return fieldlist
+
+
+def render_applicable_types(applicable_types: List[str]):
+ fieldlist = nodes.field_list()
+ for applicable_type in applicable_types:
+ fieldlist += nodes.field("", nodes.literal(text=applicable_type), nodes.Text(""))
+        fieldlist += nodes.raw("", "<br/>", format="html")
+ return fieldlist
+
+
+class DataAdapterTableDirective(Directive):
+ """Custom directive to render a table of all data adapters. Takes in one argument
+ that is either 'loader' or 'saver' to indicate which adapters to render."""
+
+ has_content = True
+ required_arguments = 1 # Number of required arguments
+
+ def run(self):
+ """Runs the directive. This does the following:
+ 1. Collects all loaders from the registry
+ 2. Creates a table with the following columns:
+ - Key
+ - Class name
+ - Class path
+ - Load params
+ - Applicable types
+ 3. Returns the table
+ :return: A list of nodes that Sphinx will render, consisting of the table node
+ """
+ saver_or_loader = self.arguments[0]
+ if saver_or_loader not in ("loader", "saver"):
+ raise ValueError(
+ f"loader_or_saver must be one of 'loader' or 'saver', " f"got {saver_or_loader}"
+ )
+ table_data = [
+ AdapterInfo.from_loader(loader) for loader in _collect_loaders(saver_or_loader)
+ ]
+
+ # Create the table and add columns
+ table_node = nodes.table()
+        tgroup = nodes.tgroup(cols=4)
+ table_node += tgroup
+
+ # Create columns
+ key_spec = nodes.colspec(colwidth=1)
+ # class_spec = nodes.colspec(colwidth=1)
+ load_params_spec = nodes.colspec(colwidth=2)
+ applicable_types_spec = nodes.colspec(colwidth=1)
+ class_path_spec = nodes.colspec(colwidth=1)
+
+ tgroup += [key_spec, load_params_spec, applicable_types_spec, class_path_spec]
+
+ # Create the table body
+ thead = nodes.thead()
+ row = nodes.row()
+
+ # Create entry nodes for each cell
+ key_entry = nodes.entry()
+ load_params_entry = nodes.entry()
+ applicable_types_entry = nodes.entry()
+ class_path_entry = nodes.entry()
+
+ key_entry += nodes.paragraph(text="key")
+
+ load_params_entry += nodes.paragraph(text=f"{saver_or_loader} params")
+ applicable_types_entry += nodes.paragraph(text="types")
+ class_path_entry += nodes.paragraph(text="module")
+
+ row += [key_entry, load_params_entry, applicable_types_entry, class_path_entry]
+ thead += row
+ tgroup += thead
+ tbody = nodes.tbody()
+ tgroup += tbody
+
+        # Populate table rows from the collected adapter info
+ for row_data in table_data:
+ row = nodes.row()
+
+ # Create entry nodes for each cell
+ key_entry = nodes.entry()
+ load_params_entry = nodes.entry()
+ applicable_types_entry = nodes.entry()
+ class_path_entry = nodes.entry()
+
+            # Render each cell's contents and add them to the entry nodes
+ key_entry += render_key(row_data.key)
+ load_params_entry += render_adapter_params(row_data.load_params)
+ applicable_types_entry += render_applicable_types(row_data.applicable_types)
+ class_path_entry += render_class_path(
+ row_data.class_path, row_data.file_, *row_data.line_nos
+ )
+
+ # Add the entry nodes to the row
+ row += [key_entry, load_params_entry, applicable_types_entry, class_path_entry]
+
+ # Add the row to the table body
+ tbody += row
+
+ return [table_node]
+
+
+def setup(app):
+ """Required to register the extension"""
+ app.add_directive("data_adapter_table", DataAdapterTableDirective)
diff --git a/docs/index.md b/docs/index.md
index b82c522e8..ee17efafe 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -41,6 +41,7 @@ contributing
reference/decorators/index
reference/drivers/index
+reference/io/index
reference/graph-adapters/index
reference/result-builders/index
reference/miscellaneous/index
diff --git a/docs/reference/io/adapter-documentation.rst b/docs/reference/io/adapter-documentation.rst
new file mode 100644
index 000000000..6394a4c02
--- /dev/null
+++ b/docs/reference/io/adapter-documentation.rst
@@ -0,0 +1,20 @@
+=========================
+Data Adapters
+=========================
+
+Reference for data adapter base classes:
+
+.. autoclass:: hamilton.io.data_adapters.DataLoader
+ :special-members: __init__
+ :members:
+ :inherited-members:
+
+.. autoclass:: hamilton.io.data_adapters.DataSaver
+ :special-members: __init__
+ :members:
+ :inherited-members:
+
+.. autoclass:: hamilton.io.data_adapters.AdapterCommon
+ :special-members: __init__
+ :members:
+ :inherited-members:
diff --git a/docs/reference/io/available-data-adapters.rst b/docs/reference/io/available-data-adapters.rst
new file mode 100644
index 000000000..a8c5a2469
--- /dev/null
+++ b/docs/reference/io/available-data-adapters.rst
@@ -0,0 +1,56 @@
+========================
+Using Data Adapters
+========================
+
+This is an index of all the available data adapters, both savers and loaders.
+Note that some adapters serve as both savers and loaders (a single class can handle both),
+while others handle only one. You will want to reference this when using any of the following:
+
+1. Using :doc:`/reference/decorators/save_to/`.
+2. Using :doc:`/reference/decorators/load_from/`.
+3. Using :doc:`materialize </reference/drivers/index>`.
+
+To read these tables, you want to first look at the key to determine which format you want --
+these should be human-readable and familiar to you. Then you'll want to look at the `types` field
+to figure out which is the best for your case (the object you want to load from or save to).
+
+Finally, look up the adapter params to see what parameters you can pass to the data adapters.
+The optional params come with their default value specified.
+
+If you want more information, click on the `module` link; it will take you to the code that implements
+the adapter, so you can see how the parameters are used.
+
+As an example, say we wanted to save a pandas dataframe to a CSV file. We would first find the
+key `csv`, which would inform us that we want to call `save_to.csv` (or `to.csv` in the case
+of `materialize`). Then, we would look at the `types` field, finding that there is a pandas
+dataframe adapter. Finally, we would look at the `params` field, finding that we can pass
+`path`, and (optionally) `sep` (which we'd realize defaults to `,` when looking at the code).
+
+All together, we'd end up with:
+
+.. code-block:: python
+
+ import pandas as pd
+ from hamilton.function_modifiers import value, save_to
+
+ @save_to.csv(path=value("my_file.csv"))
+ def my_data(...) -> pd.DataFrame:
+ ...
+
+And we're good to go!
+
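+Equivalently, if you would rather materialize at execution time than with a decorator, a minimal
+sketch (assuming a built driver `dr` and a node named `my_data`) could look like:
+
+.. code-block:: python
+
+    from hamilton.io.materialization import to
+
+    metadata, _ = dr.materialize(
+        to.csv(
+            id="my_data_to_csv",       # name of the node added to the DAG
+            dependencies=["my_data"],  # node whose output we save
+            path="./my_file.csv",
+        ),
+        additional_vars=[],
+    )
+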
+If you want to extend these, see :doc:`/reference/io/adapter-documentation` for the base classes,
+and `the materialization example <https://github.com/dagworks-inc/hamilton/tree/main/examples/materialization>`_
+in the repository for an example of how to do so.
+
+=============
+Data Loaders
+=============
+
+.. data_adapter_table:: loader
+
+=============
+Data Savers
+=============
+
+.. data_adapter_table:: saver
diff --git a/docs/reference/io/index.rst b/docs/reference/io/index.rst
new file mode 100644
index 000000000..637a69c03
--- /dev/null
+++ b/docs/reference/io/index.rst
@@ -0,0 +1,11 @@
+==============
+I/O
+==============
+
+This section contains information about I/O within Hamilton.
+
+.. toctree::
+ :maxdepth: 2
+
+ available-data-adapters
+ adapter-documentation
diff --git a/examples/materialization/README.md b/examples/materialization/README.md
new file mode 100644
index 000000000..3a6c73e8d
--- /dev/null
+++ b/examples/materialization/README.md
@@ -0,0 +1,66 @@
+# Materialization
+
+Hamilton's driver allows for ad-hoc materialization. This enables you to take a DAG you already have,
+and save your data to a set of custom locations/URLs.
+
+Note that these materializers are _isomorphic_ in nature to the
+[@save_to](https://hamilton.dagworks.io/en/latest/reference/decorators/save_to/)
+decorator. Materializers inject additional nodes at runtime, modifying the
+DAG to include data saver nodes, and return the metadata around materialization.
+
+This framework is meant to be highly pluggable. While the set of available data savers is currently
+limited, we expect folks to build their own materializers (and, hopefully, contribute them back to the community!).
+
+
+## Example
+In this example we take the scikit-learn iris_loader pipeline, and materialize outputs to specific
+locations through a driver call. We demonstrate:
+
+1. Saving model parameters to a json file (using the default json materializer)
+2. Writing custom data adapters for:
+ 1. Pickling a model to an object file
+ 2. Saving confusion matrices to a csv file
+
+See [run.py](run.py) for the full example.
+
+In this example we only pass literal values to the materializers. That said, you can use both `source` (to take the
+value from an upstream node) and `value` (the default) to specify literals, as sketched below.
+
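+For instance, a sketch using `source` (the upstream node name `output_path` here is hypothetical):
+
+```python
+from hamilton.function_modifiers import source
+from hamilton.io.materialization import to
+
+# pass this to dr.materialize(...) as in run.py
+materializer = to.csv(
+    dependencies=["predicted_output_with_labels"],
+    id="predicted_output_with_labels_to_csv",
+    path=source("output_path"),  # taken from the upstream node at runtime, not a literal
+)
+```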
+
+## `driver.materialize`
+
+This will be a high-level overview. For more details,
+see the [documentation](https://hamilton.dagworks.io/en/latest/reference/drivers/Driver/#hamilton.driver.Driver.materialize).
+
+`driver.materialize()` does the following:
+1. Processes a list of materializers to create a new DAG
+2. Alters the output to include the materializer nodes
+3. Processes a list of "additional variables" (for debugging) to return intermediary data
+4. Executes the DAG, including the materializers
+5. Returns a tuple of (`materialization metadata`, `additional variables`)
+
+Materializers each consume:
+1. A `dependencies` list to materialize
+2. An (optional) `combine` parameter to combine the outputs of the dependencies
+(this is required if there are multiple dependencies). This is a [ResultMixin](https://hamilton.dagworks.io/en/latest/concepts/customizing-execution/#result-builders) object
+3. An `id` parameter to identify the materializer, which serves as the node name in the DAG (see the sketch below)
+
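+As an illustrative sketch (the second dependency and the use of `base.PandasDataFrameResult` as the
+`combine` argument are assumptions for demonstration; `dr` is the driver built in [run.py](run.py)):
+
+```python
+from hamilton import base
+from hamilton.io.materialization import to
+
+metadata, extra = dr.materialize(
+    to.csv(
+        id="labels_to_csv",  # the node name this materializer adds to the DAG
+        dependencies=["predicted_output_with_labels", "y_test_with_labels"],
+        combine=base.PandasDataFrameResult(),  # required because there are two dependencies
+        path="./data/labels.csv",
+    ),
+    additional_vars=["classification_report"],  # debugging outputs returned alongside the metadata
+)
+```
+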
+Materializers are referenced by the `to` object in `hamilton.io.materialization`, which utilizes
+dynamic dispatch to create the appropriate materializer.
+
+These refer to `DataSaver` implementations, which are keyed by a string (e.g. `csv`).
+Multiple data adapters can share the same key, with each applying to a specific type
+(e.g. pandas dataframe, numpy matrix, polars dataframe). New data adapters are registered
+by calling `hamilton.registry.register_adapter` (see the sketch in the next section).
+
+## Custom Materializers
+
+To define a custom materializer, all you have to do is implement the `DataSaver` interface
+(which will allow use in `save_to` as well). This is demonstrated in [custom_materializers.py](custom_materializers.py).
+
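+As a rough outline, a minimal sketch of the shape of a custom saver (the `TextFileSaver` class and
+its `text` key are hypothetical; see [custom_materializers.py](custom_materializers.py) for the
+working adapters used in this example):
+
+```python
+import dataclasses
+from typing import Any, Collection, Dict, Type
+
+from hamilton import registry
+from hamilton.io import utils
+from hamilton.io.data_adapters import DataSaver
+
+
+@dataclasses.dataclass
+class TextFileSaver(DataSaver):
+    path: str
+
+    def save_data(self, data: str) -> Dict[str, Any]:
+        with open(self.path, "w") as f:
+            f.write(data)
+        return utils.get_file_metadata(self.path)  # metadata returned by the materializer node
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        return [str]  # dispatch: this saver handles string outputs
+
+    @classmethod
+    def name(cls) -> str:
+        return "text"  # the key, i.e. `to.text(...)` / `save_to.text(...)`
+
+
+registry.register_adapter(TextFileSaver)
+```
+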
+## `driver.materialize` vs `@save_to`
+
+`driver.materialize` is an ad-hoc form of `save_to`. Use it when you're developing and
+want to do ad-hoc materialization. When you have a production ETL, you can choose between `save_to` and `materialize`:
+if the save location/structure is unlikely to change, you might consider using `save_to`; otherwise, `materialize`
+is an idiomatic way of conducting materialization operations that cleanly separates side effects from transformations.
diff --git a/examples/materialization/custom_materializers.py b/examples/materialization/custom_materializers.py
new file mode 100644
index 000000000..a2508ae43
--- /dev/null
+++ b/examples/materialization/custom_materializers.py
@@ -0,0 +1,55 @@
+import dataclasses
+import pickle
+from typing import Any, Collection, Dict, Type
+
+import numpy as np
+from sklearn import base
+
+from hamilton import registry
+from hamilton.io import utils
+from hamilton.io.data_adapters import DataSaver
+
+# TODO -- put this back in the standard library
+
+
+@dataclasses.dataclass
+class NumpyMatrixToCSV(DataSaver):
+ path: str
+ sep: str = ","
+
+ def __post_init__(self):
+ if not self.path.endswith(".csv"):
+ raise ValueError(f"CSV files must end with .csv, got {self.path}")
+
+ def save_data(self, data: np.ndarray) -> Dict[str, Any]:
+ np.savetxt(self.path, data, delimiter=self.sep)
+ return utils.get_file_metadata(self.path)
+
+ @classmethod
+ def applicable_types(cls) -> Collection[Type]:
+ return [np.ndarray]
+
+ @classmethod
+ def name(cls) -> str:
+ return "csv"
+
+
+@dataclasses.dataclass
+class SKLearnPickler(DataSaver):
+ path: str
+
+ def save_data(self, data: base.ClassifierMixin) -> Dict[str, Any]:
+ pickle.dump(data, open(self.path, "wb"))
+ return utils.get_file_metadata(self.path)
+
+ @classmethod
+ def applicable_types(cls) -> Collection[Type]:
+ return [base.ClassifierMixin]
+
+ @classmethod
+ def name(cls) -> str:
+ return "pickle"
+
+
+for adapter in [NumpyMatrixToCSV, SKLearnPickler]:
+ registry.register_adapter(adapter)
diff --git a/examples/materialization/dag.pdf b/examples/materialization/dag.pdf
new file mode 100644
index 000000000..3d123bb07
Binary files /dev/null and b/examples/materialization/dag.pdf differ
diff --git a/examples/materialization/data_loaders.py b/examples/materialization/data_loaders.py
new file mode 100644
index 000000000..e31061eab
--- /dev/null
+++ b/examples/materialization/data_loaders.py
@@ -0,0 +1,30 @@
+import numpy as np
+from sklearn import datasets, utils
+
+from hamilton.function_modifiers import config
+
+"""
+Module to load the example dataset (iris or digits), selected via configuration.
+"""
+
+
+@config.when(data_loader="iris")
+def data__iris() -> utils.Bunch:
+    return datasets.load_iris()
+
+
+@config.when(data_loader="digits")
+def data__digits() -> utils.Bunch:
+ return datasets.load_digits()
+
+
+def target(data: utils.Bunch) -> np.ndarray:
+ return data.target
+
+
+def target_names(data: utils.Bunch) -> np.ndarray:
+ return data.target_names
+
+
+def feature_matrix(data: utils.Bunch) -> np.ndarray:
+ return data.data
diff --git a/examples/materialization/model_training.py b/examples/materialization/model_training.py
new file mode 100644
index 000000000..894cc533a
--- /dev/null
+++ b/examples/materialization/model_training.py
@@ -0,0 +1,96 @@
+from typing import Dict
+
+import numpy as np
+from sklearn import base, linear_model, metrics, svm
+from sklearn.model_selection import train_test_split
+
+from hamilton import function_modifiers
+
+
+@function_modifiers.config.when(clf="svm")
+def prefit_clf__svm(gamma: float = 0.001) -> base.ClassifierMixin:
+ """Returns an unfitted SVM classifier object.
+
+ :param gamma: ...
+ :return:
+ """
+ return svm.SVC(gamma=gamma)
+
+
+@function_modifiers.config.when(clf="logistic")
+def prefit_clf__logreg(penalty: str) -> base.ClassifierMixin:
+ """Returns an unfitted Logistic Regression classifier object.
+
+ :param penalty:
+ :return:
+ """
+    return linear_model.LogisticRegression(penalty=penalty)
+
+
+@function_modifiers.extract_fields(
+ {"X_train": np.ndarray, "X_test": np.ndarray, "y_train": np.ndarray, "y_test": np.ndarray}
+)
+def train_test_split_func(
+ feature_matrix: np.ndarray,
+ target: np.ndarray,
+ test_size_fraction: float,
+ shuffle_train_test_split: bool,
+) -> Dict[str, np.ndarray]:
+ """Function that creates the training & test splits.
+
+    This is then extracted out into constituent components and used downstream.
+
+ :param feature_matrix:
+ :param target:
+ :param test_size_fraction:
+ :param shuffle_train_test_split:
+ :return:
+ """
+ X_train, X_test, y_train, y_test = train_test_split(
+ feature_matrix, target, test_size=test_size_fraction, shuffle=shuffle_train_test_split
+ )
+ return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test}
+
+
+def y_test_with_labels(y_test: np.ndarray, target_names: np.ndarray) -> np.ndarray:
+ """Adds labels to the target output."""
+ return np.array([target_names[idx] for idx in y_test])
+
+
+def fit_clf(
+ prefit_clf: base.ClassifierMixin, X_train: np.ndarray, y_train: np.ndarray
+) -> base.ClassifierMixin:
+ """Calls fit on the classifier object; it mutates it."""
+ prefit_clf.fit(X_train, y_train)
+ return prefit_clf
+
+
+def predicted_output(fit_clf: base.ClassifierMixin, X_test: np.ndarray) -> np.ndarray:
+    """Exercises the fit classifier to perform a prediction."""
+ return fit_clf.predict(X_test)
+
+
+def predicted_output_with_labels(
+ predicted_output: np.ndarray, target_names: np.ndarray
+) -> np.ndarray:
+ """Replaces the predictions with the desired labels."""
+ return np.array([target_names[idx] for idx in predicted_output])
+
+
+def classification_report(
+ predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray
+) -> str:
+ """Returns a classification report."""
+ return metrics.classification_report(y_test_with_labels, predicted_output_with_labels)
+
+
+def confusion_matrix(
+ predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray
+) -> np.ndarray:
+ """Returns a confusion matrix report."""
+ return metrics.confusion_matrix(y_test_with_labels, predicted_output_with_labels)
+
+
+def model_parameters(fit_clf: base.ClassifierMixin) -> dict:
+ """Returns a dictionary of model parameters."""
+ return fit_clf.get_params()
diff --git a/examples/materialization/notebook.ipynb b/examples/materialization/notebook.ipynb
new file mode 100644
index 000000000..0a414930a
--- /dev/null
+++ b/examples/materialization/notebook.ipynb
@@ -0,0 +1,535 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "7bf6a40d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import os\n",
+ "\n",
+ "import data_loaders\n",
+ "import model_training\n",
+ "\n",
+ "from hamilton import base, driver\n",
+ "from hamilton.io.materialization import to\n",
+ "import pandas as pd\n",
+ "\n",
+ "import custom_materializers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "7a449245",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dag_config = {\n",
+ " \"test_size_fraction\": 0.5,\n",
+ " \"shuffle_train_test_split\": True,\n",
+ " \"data_loader\" : \"iris\",\n",
+ " \"clf\" : \"logistic\",\n",
+ " \"penalty\" : \"l2\"\n",
+ "}\n",
+ "dr = (\n",
+ " driver.Builder()\n",
+ " .with_adapter(base.DefaultAdapter())\n",
+ " .with_config(dag_config)\n",
+ " .with_modules(data_loaders, model_training)\n",
+ " .build()\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "397b09bc",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "materializers = [\n",
+ " to.json(\n",
+ " dependencies=[\"model_parameters\"],\n",
+ " id=\"model_params_to_json\",\n",
+ " path=\"./data/params.json\"\n",
+ " ),\n",
+ " # classification report to .txt file\n",
+ " to.file(\n",
+ " dependencies=[\"classification_report\"],\n",
+ " id=\"classification_report_to_txt\",\n",
+ " path=\"./data/classification_report.txt\",\n",
+ " ),\n",
+ " # materialize the model to a pickle file\n",
+ " to.pickle(\n",
+ " dependencies=[\"fit_clf\"], id=\"clf_to_pickle\", path=\"./data/clf.pkl\"\n",
+ " ),\n",
+ " # materialize the predictions we made to a csv file\n",
+ " to.csv(\n",
+ " dependencies=[\"predicted_output_with_labels\"],\n",
+ " id=\"predicted_output_with_labels_to_csv\",\n",
+ " path=\"./data/predicted_output_with_labels.csv\",\n",
+ " ),\n",
+ " ]\n",
+ "\n",
+ "dr.visualize_materialization(\n",
+ " *materializers,\n",
+ " additional_vars=[\"classification_report\"],\n",
+ " output_file_path=None,\n",
+ " render_kwargs={},\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "f5727b54",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/elijahbenizzy/.pyenv/versions/3.9.10/envs/hamilton/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
+ "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
+ "\n",
+ "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
+ " https://scikit-learn.org/stable/modules/preprocessing.html\n",
+ "Please also refer to the documentation for alternative solver options:\n",
+ " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
+ " n_iter_i = _check_optimize_result(\n"
+ ]
+ }
+ ],
+ "source": [
+ "materialization_results, additional_vars = dr.materialize(\n",
+ " # materialize model parameters to json\n",
+ " *materializers,\n",
+ " additional_vars=[\"classification_report\"],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "8bdfde70",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 1.00 1.00 1.00 94\n",
+ " 1 0.91 0.93 0.92 85\n",
+ " 2 0.97 0.99 0.98 96\n",
+ " 3 0.99 0.97 0.98 93\n",
+ " 4 0.99 0.92 0.95 88\n",
+ " 5 0.95 0.95 0.95 85\n",
+ " 6 0.99 0.97 0.98 97\n",
+ " 7 0.97 0.97 0.97 89\n",
+ " 8 0.88 0.88 0.88 82\n",
+ " 9 0.91 0.97 0.94 90\n",
+ "\n",
+ " accuracy 0.96 899\n",
+ " macro avg 0.95 0.95 0.95 899\n",
+ "weighted avg 0.96 0.96 0.96 899\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(additional_vars['classification_report'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "a6f5fe83",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 1.00 1.00 1.00 94\n",
+ " 1 0.91 0.93 0.92 85\n",
+ " 2 0.97 0.99 0.98 96\n",
+ " 3 0.99 0.97 0.98 93\n",
+ " 4 0.99 0.92 0.95 88\n",
+ " 5 0.95 0.95 0.95 85\n",
+ " 6 0.99 0.97 0.98 97\n",
+ " 7 0.97 0.97 0.97 89\n",
+ " 8 0.88 0.88 0.88 82\n",
+ " 9 0.91 0.97 0.94 90\n",
+ "\n",
+ " accuracy 0.96 899\n",
+ " macro avg 0.95 0.95 0.95 899\n",
+ "weighted avg 0.96 0.96 0.96 899\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(open((materialization_results['classification_report_to_txt']['path'])).read())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/materialization/requirements.txt b/examples/materialization/requirements.txt
new file mode 100644
index 000000000..3f69ad5c2
--- /dev/null
+++ b/examples/materialization/requirements.txt
@@ -0,0 +1,2 @@
+scikit-learn
+sf-hamilton
diff --git a/examples/materialization/run.py b/examples/materialization/run.py
new file mode 100644
index 000000000..fce1faee7
--- /dev/null
+++ b/examples/materialization/run.py
@@ -0,0 +1,87 @@
+"""
+Example script showing how one might setup a generic model training pipeline that is quickly configurable.
+"""
+
+import importlib
+
+import os
+
+import data_loaders
+import model_training
+
+from hamilton import base, driver
+from hamilton.io.materialization import to
+
+# This has to be imported, but the linter doesn't like it cause its unused
+# We just need to import it to register the materializers
+importlib.import_module("custom_materializers")
+
+
+def get_model_config(model_type: str) -> dict:
+ """Returns model type specific configuration"""
+ if model_type == "svm":
+ return {"clf": "svm", "gamma": 0.001}
+ elif model_type == "logistic":
+ return {"clf": "logistic", "penalty": "l2"}
+ else:
+ raise ValueError(f"Unsupported model {model_type}.")
+
+
+if __name__ == "__main__":
+ import sys
+
+ if len(sys.argv) < 3:
+ print("Error: required arguments are [iris|digits] [svm|logistic]")
+ sys.exit(1)
+ _data_set = sys.argv[1] # the data set to load
+ _model_type = sys.argv[2] # the model type to fit and evaluate with
+
+ dag_config = {
+ "test_size_fraction": 0.5,
+ "shuffle_train_test_split": True,
+ }
+ if not os.path.exists("data"):
+ os.mkdir("data")
+ # augment config
+ dag_config.update(get_model_config(_model_type))
+ dag_config["data_loader"] = _data_set
+ dr = (
+ driver.Builder()
+ .with_adapter(base.DefaultAdapter())
+ .with_config(dag_config)
+ .with_modules(data_loaders, model_training)
+ .build()
+ )
+ materializers = [
+ to.json(
+ dependencies=["model_parameters"], id="model_params_to_json", path="./data/params.json"
+ ),
+ # classification report to .txt file
+ to.file(
+ dependencies=["classification_report"],
+ id="classification_report_to_txt",
+ path="./data/classification_report.txt",
+ ),
+ # materialize the model to a pickle file
+ to.pickle(dependencies=["fit_clf"], id="clf_to_pickle", path="./data/clf.pkl"),
+ # materialize the predictions we made to a csv file
+ to.csv(
+ dependencies=["predicted_output_with_labels"],
+ id="predicted_output_with_labels_to_csv",
+ path="./data/predicted_output_with_labels.csv",
+ ),
+ ]
+ dr.visualize_materialization(
+ *materializers,
+ additional_vars=["classification_report"],
+ output_file_path="./dag",
+ render_kwargs={},
+ )
+ materialization_results, additional_vars = dr.materialize(
+ # materialize model parameters to json
+ *materializers,
+ additional_vars=["classification_report"],
+ )
+    print(materialization_results)
+    print(additional_vars["classification_report"])
diff --git a/hamilton/driver.py b/hamilton/driver.py
index 28e738b88..dab5f7d08 100644
--- a/hamilton/driver.py
+++ b/hamilton/driver.py
@@ -997,9 +997,9 @@ def materialize(
def visualize_materialization(
self,
*materializers: materialization.MaterializerFactory,
- additional_vars: List[Union[str, Callable, Variable]],
output_file_path: str,
render_kwargs: dict,
+ additional_vars: List[Union[str, Callable, Variable]] = None,
inputs: Dict[str, Any] = None,
graphviz_kwargs: dict = None,
) -> Optional["graphviz.Digraph"]: # noqa F821
@@ -1014,11 +1014,13 @@ def visualize_materialization(
:param graphviz_kwargs: Arguments to pass to graphviz
:return: The graphviz graph, if you want to do something with it
"""
+ if additional_vars is None:
+ additional_vars = []
function_graph = materialization.modify_graph(self.graph, materializers)
_final_vars = self._create_final_vars(additional_vars) + [
materializer.id for materializer in materializers
]
- Driver._visualize_execution_helper(
+ return Driver._visualize_execution_helper(
function_graph,
self.adapter,
_final_vars,
@@ -1177,9 +1179,7 @@ def build(self) -> Driver:
execution_manager = self.execution_manager
if execution_manager is None:
local_executor = self.local_executor or executors.SynchronousLocalTaskExecutor()
- remote_executor = self.remote_executor or executors.MultiProcessingExecutor(
- max_tasks=10
- )
+ remote_executor = self.remote_executor or executors.MultiThreadingExecutor(max_tasks=10)
execution_manager = executors.DefaultExecutionManager(
local_executor=local_executor, remote_executor=remote_executor
)
diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py
index a5deb9ae3..333561516 100644
--- a/hamilton/function_modifiers/adapters.py
+++ b/hamilton/function_modifiers/adapters.py
@@ -1,6 +1,6 @@
import inspect
import typing
-from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Collection, Dict, List, Tuple, Type
from hamilton import node
from hamilton.function_modifiers.base import (
@@ -139,7 +139,7 @@ def _select_param_to_inject(self, params: List[str], fn: Callable) -> str:
def inject_nodes(
self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
- ) -> Optional[Collection[node.Node]]:
+ ) -> Tuple[Collection[node.Node], Dict[str, str]]:
pass
"""Generates two nodes:
1. A node that loads the data from the data source, and returns that + metadata
@@ -217,7 +217,7 @@ def get_input_type_key(key: str) -> str:
"hamilton.data_loader.classname": f"{loader_cls.__qualname__}",
"hamilton.data_loader.node": inject_parameter,
},
- namespace=("load_data", fn.__name__),
+ namespace=(fn.__name__, "load_data"),
)
# the filter node is the node that takes the data from the data source, filters out
@@ -239,8 +239,9 @@ def filter_function(_inject_parameter=inject_parameter, **kwargs):
"hamilton.data_loader.classname": f"{loader_cls.__qualname__}",
"hamilton.data_loader.node": inject_parameter,
},
+ namespace=(fn.__name__, "select_data"),
)
- return [loader_node, filter_node]
+ return [loader_node, filter_node], {inject_parameter: filter_node.name}
def _get_inject_parameter_from_function(self, fn: Callable) -> Tuple[str, Type[Type]]:
"""Gets the name of the parameter to inject the data into.
@@ -311,7 +312,9 @@ def __getattr__(cls, item: str):
f"Available loaders are: {LOADER_REGISTRY.keys()}. "
f"If you've gotten to this point, you either (1) spelled the "
f"loader name wrong, (2) are trying to use a loader that does"
- f"not exist (yet)"
+ f"not exist (yet). For a list of available loaders, see: "
+ f"https://hamilton.readthedocs.io/reference/io/available-data-adapters/#data"
+ f"-loaders "
) from e
@@ -424,11 +427,13 @@ def __getattr__(cls, item: str):
return super().__getattribute__(item)
except AttributeError as e:
raise AttributeError(
- f"No saver named: {item} available for {cls.__name__}. "
- f"Available data savers are: {list(SAVER_REGISTRY.keys())}. "
- f"If you've gotten to this point, you either (1) spelled the "
- f"loader name wrong, (2) are trying to use a saver that does"
- f"not exist (yet)."
+                f"No saver named: {item} available for {cls.__name__}. "
+                f"Available data savers are: {list(SAVER_REGISTRY.keys())}. "
+ "If you've gotten to this point, you either (1) spelled the "
+                "saver name wrong, (2) are trying to use a saver that does "
+ "not exist (yet). For a list of available savers, see "
+ "https://hamilton.readthedocs.io/reference/io/available-data-adapters/#data"
+ "-loaders "
) from e
diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py
index 808999365..debbe37f9 100644
--- a/hamilton/function_modifiers/base.py
+++ b/hamilton/function_modifiers/base.py
@@ -10,7 +10,7 @@
except ImportError:
# python3.10 and above
EllipsisType = type(...)
-from typing import Any, Callable, Collection, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union
from hamilton import node, registry, settings
@@ -228,6 +228,26 @@ def transform_dag(
pass
+# TODO -- delete this/replace with the version that will be added by
+# https://github.com/DAGWorks-Inc/hamilton/pull/249/ as part of the Node class
+def _reassign_input_names(node_: node.Node, input_names: Dict[str, Any]) -> node.Node:
+ """Reassigns the input names of a node. Useful for applying
+ a node to a separate input if needed. Note that things can get a
+ little strange if you have multiple inputs with the same name, so
+ be careful about how you use this.
+ :param input_names: Input name map to reassign
+ :return: A node with the input names reassigned
+ """
+
+ def new_callable(**kwargs) -> Any:
+ reverse_input_names = {v: k for k, v in input_names.items()}
+ return node_.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()})
+
+ new_input_types = {input_names.get(k, k): v for k, v in node_.input_types.items()}
+ out = node_.copy_with(callabl=new_callable, input_types=new_input_types)
+ return out
+
+
class NodeInjector(SubDAGModifier, abc.ABC):
"""Injects a value as a source node in the DAG. This is a special case of the SubDAGModifier,
which gets all the upstream (required) nodes from the subdag and gives the decorator a chance
@@ -275,21 +295,37 @@ def transform_dag(
:return:
"""
injectable_params = NodeInjector.find_injectable_params(nodes)
- out = list(nodes)
- out.extend(self.inject_nodes(injectable_params, config, fn))
+ nodes_to_inject, rename_map = self.inject_nodes(injectable_params, config, fn)
+ out = []
+ for node_ in nodes:
+ # if there's an intersection then we want to rename the input
+ if set(node_.input_types.keys()) & set(rename_map.keys()):
+ out.append(_reassign_input_names(node_, rename_map))
+ else:
+ out.append(node_)
+ out.extend(nodes_to_inject)
+        # Guard against duplicate node names introduced by injection/renaming
+        node_names = [node_.name for node_ in out]
+        if len(set(node_names)) != len(node_names):
+            raise ValueError(f"Duplicate node names found after injecting nodes: {node_names}")
return out
@abc.abstractmethod
def inject_nodes(
self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
- ) -> List[node.Node]:
+ ) -> Tuple[List[node.Node], Dict[str, str]]:
"""Adds a set of nodes to inject into the DAG. These get injected into the specified param name,
- meaning that exactly one of the output nodes will have that name.
+ meaning that exactly one of the output nodes will have that name. Note that this also allows
+ input renaming, meaning that the injector can rename the input to something else (to avoid
+ name-clashes).
:param params: Dictionary of all the type names one wants to inject
:param config: Configuration with which the DAG was constructed.
:param fn: original function we're decorating. This is useful largely for debugging.
- :return: A list of nodes to add. Empty if you wish to inject nothing
+ :return: A list of nodes to add. Empty if you wish to inject nothing, as well as a dictionary,
+ allowing the injector to rename the inputs (e.g. if you want the name to be
+ namespaced to avoid clashes)
"""
pass
diff --git a/hamilton/function_modifiers/dependencies.py b/hamilton/function_modifiers/dependencies.py
index 1d3dcbd7f..575ed7188 100644
--- a/hamilton/function_modifiers/dependencies.py
+++ b/hamilton/function_modifiers/dependencies.py
@@ -29,7 +29,7 @@ class SingleDependency(ParametrizedDependency, abc.ABC):
@dataclasses.dataclass
-class LiteralDependency(ParametrizedDependency):
+class LiteralDependency(SingleDependency):
value: Any
def get_dependency_type(self) -> ParametrizedDependencySource:
@@ -37,7 +37,7 @@ def get_dependency_type(self) -> ParametrizedDependencySource:
@dataclasses.dataclass
-class UpstreamDependency(ParametrizedDependency):
+class UpstreamDependency(SingleDependency):
source: str
def get_dependency_type(self) -> ParametrizedDependencySource:
diff --git a/hamilton/graph.py b/hamilton/graph.py
index 2f82fd645..57ec99ca1 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -68,19 +68,23 @@ def add_dependency(
def update_dependencies(
- nodes: Dict[str, node.Node], adapter: base.HamiltonGraphAdapter, in_place: bool = True
+ nodes: Dict[str, node.Node], adapter: base.HamiltonGraphAdapter, reset_dependencies: bool = True
):
- """Adds dependecies to a dictionary of nodes. If in_place is False,
+    """Adds dependencies to a dictionary of nodes. If reset_dependencies is True,
it will deepcopy the dict + nodes and return that. Otherwise it will
mutate + return the passed-in dict + nodes.
:param in_place: Whether or not to modify in-place, or copy/return
:param nodes: Nodes that form the DAG we're updating
:param adapter: Adapter to use for type checking
+ :param reset_dependencies: Whether or not to reset the dependencies. If they are not set this is
+ unnecessary, and we can save yet another pass. Note that `reset` will perform an in-place
+ operation.
:return: The updated nodes
"""
- if not in_place:
- nodes = {k: v for k, v in nodes.items()}
+ # copy without the dependencies to avoid duplicates
+ if reset_dependencies:
+ nodes = {k: v.copy(include_refs=False) for k, v in nodes.items()}
for node_name, n in list(nodes.items()):
for param_name, (param_type, _) in n.input_types.items():
add_dependency(n, node_name, nodes, param_name, param_type, adapter)
@@ -118,7 +122,8 @@ def create_function_graph(
)
nodes[n.name] = n
# add dependencies -- now that all nodes exist, we just run through edges & validate graph.
- update_dependencies(nodes, adapter) # in place
+    # no dependencies are present yet, so there is nothing to reset
+    nodes = update_dependencies(nodes, adapter, reset_dependencies=False)
for key in config.keys():
if key not in nodes:
nodes[key] = node.Node(key, Any, node_source=node.NodeType.EXTERNAL)
diff --git a/hamilton/io/data_adapters.py b/hamilton/io/data_adapters.py
index 965a1f20f..b997227e6 100644
--- a/hamilton/io/data_adapters.py
+++ b/hamilton/io/data_adapters.py
@@ -153,12 +153,12 @@ class DataSaver(AdapterCommon, abc.ABC):
@abc.abstractmethod
def save_data(self, data: Any) -> Dict[str, Any]:
"""Saves the data to the data source.
- Note this uses the constructor parameters to determine
- how to save the data.
+ Note this uses the constructor parameters to determine
+ how to save the data.
:return: Any relevant metadata. This is up the the data saver, but will likely
- include the URI, etc... This is going to be similar to the metadata returned
- by the data loader in the loading tuple.
+ include the URI, etc... This is going to be similar to the metadata returned
+ by the data loader in the loading tuple.
"""
pass
diff --git a/hamilton/io/materialization.py b/hamilton/io/materialization.py
index 3688d7a6f..09952cb82 100644
--- a/hamilton/io/materialization.py
+++ b/hamilton/io/materialization.py
@@ -7,7 +7,7 @@
from hamilton.function_modifiers.dependencies import SingleDependency, value
from hamilton.graph import FunctionGraph
from hamilton.io.data_adapters import DataSaver
-from hamilton.registry import LOADER_REGISTRY
+from hamilton.registry import SAVER_REGISTRY
class materialization_meta__(type):
@@ -19,20 +19,22 @@ class in registry, or make it a function that just proxies to the decorator. We
"""
def __getattr__(cls, item: str):
- if item in LOADER_REGISTRY:
- potential_loaders = LOADER_REGISTRY[item]
+ if item in SAVER_REGISTRY:
+ potential_loaders = SAVER_REGISTRY[item]
savers = [loader for loader in potential_loaders if issubclass(loader, DataSaver)]
if len(savers) > 0:
- return Materialize.partial(LOADER_REGISTRY[item])
+ return Materialize.partial(SAVER_REGISTRY[item])
try:
return super().__getattribute__(item)
except AttributeError as e:
raise AttributeError(
- f"No loader named: {item} available for {cls.__name__}. "
- f"Available loaders are: {LOADER_REGISTRY.keys()}. "
- f"If you've gotten to this point, you either (1) spelled the "
- f"loader name wrong, (2) are trying to use a loader that does"
- f"not exist (yet)"
+                f"No data materializer named: {item}. "
+                f"Available materializers are: {SAVER_REGISTRY.keys()}. "
+                "If you've gotten to this point, you either (1) spelled the "
+                "materializer name wrong, (2) are trying to use a materializer that does "
+ "not exist (yet). For a list of available materializers, see "
+ "https://hamilton.readthedocs.io/reference/io/available-data-adapters/#data"
+ "-loaders "
) from e
@@ -76,6 +78,7 @@ def _process_kwargs(
"""
processed_kwargs = {}
for kwarg, kwarg_val in data_saver_kwargs.items():
+
if not isinstance(kwarg_val, SingleDependency):
processed_kwargs[kwarg] = value(kwarg_val)
else:
diff --git a/hamilton/node.py b/hamilton/node.py
index 50f4a5c9b..3bba90aa6 100644
--- a/hamilton/node.py
+++ b/hamilton/node.py
@@ -108,7 +108,7 @@ def __init__(
DependencyType.from_parameter(value),
)
elif self.user_defined:
- if input_types is not None:
+ if len(self._input_types) > 0:
raise ValueError(
f"Input types cannot be provided for user-defined node {self.name}"
)
@@ -266,11 +266,12 @@ def from_fn(fn: Callable, name: str = None) -> "Node":
node_source=node_source,
)
- def copy_with(self, **overrides) -> "Node":
+ def copy_with(self, include_refs: bool = True, **overrides) -> "Node":
"""Copies a node with the specified overrides for the constructor arguments.
Utility function for creating a node -- useful for modifying it.
:param kwargs: kwargs to use in place of the node. Passed to the constructor.
+ :param include_refs: Whether or not to include dependencies and depended_on_by
:return: A node copied from self with the specified keyword arguments replaced.
"""
constructor_args = dict(
@@ -284,4 +285,20 @@ def copy_with(self, **overrides) -> "Node":
originating_functions=self.originating_functions,
)
constructor_args.update(**overrides)
- return Node(**constructor_args)
+ out = Node(**constructor_args)
+ if include_refs:
+ out._dependencies = self._dependencies
+ out._depended_on_by = self._depended_on_by
+ return out
+
+ def copy(self, include_refs: bool = True) -> "Node":
+ """Copies a node, not modifying anything (except for the references
+ /dependencies if specified).
+
+ :param include_refs: Whether or not to include dependencies and depended_on_by
+ :return: A copy of the node.
+ """
+ return self.copy_with(include_refs)
diff --git a/requirements-docs.txt b/requirements-docs.txt
index 36137690b..2403cff5d 100644
--- a/requirements-docs.txt
+++ b/requirements-docs.txt
@@ -3,6 +3,7 @@ alabaster>=0.7,<0.8,!=0.7.5 # read the docs pins
commonmark==0.9.1 # read the docs pins
dask[distributed]
furo
+gitpython # Required for parsing git info for generation of data-adapter docs
mock==1.0.1 # read the docs pins
myst-parser==0.18.1 # latest version of myst at this time
pillow
diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py
index 49cbda1f6..cba3470f2 100644
--- a/tests/function_modifiers/test_adapters.py
+++ b/tests/function_modifiers/test_adapters.py
@@ -75,17 +75,17 @@ def fn(data: int) -> int:
nodes_by_name = {node_.name: node_ for node_ in nodes}
assert len(nodes_by_name) == 3
assert "fn" in nodes_by_name
- assert nodes_by_name["data"].tags == {
+ assert nodes_by_name["fn.load_data.data"].tags == {
"hamilton.data_loader.source": "mock",
"hamilton.data_loader": True,
- "hamilton.data_loader.has_metadata": False,
+ "hamilton.data_loader.has_metadata": True,
"hamilton.data_loader.node": "data",
"hamilton.data_loader.classname": MockDataLoader.__qualname__,
}
- assert nodes_by_name["load_data.fn.data"].tags == {
+ assert nodes_by_name["fn.select_data.data"].tags == {
"hamilton.data_loader.source": "mock",
"hamilton.data_loader": True,
- "hamilton.data_loader.has_metadata": True,
+ "hamilton.data_loader.has_metadata": False,
"hamilton.data_loader.node": "data",
"hamilton.data_loader.classname": MockDataLoader.__qualname__,
}
@@ -333,7 +333,7 @@ def fn_str_inject(injected_data: str) -> str:
)
result = fg.execute(inputs={}, nodes=fg.nodes.values())
assert result["fn_str_inject"] == "foo"
- assert result["load_data.fn_str_inject.injected_data"] == (
+ assert result["fn_str_inject.load_data.injected_data"] == (
"foo",
{"loader": "string_data_loader"},
)
@@ -362,12 +362,12 @@ def fn_str_inject(injected_data_1: str, injected_data_2: int) -> str:
)
result = fg.execute(inputs={}, nodes=fg.nodes.values())
assert result["fn_str_inject"] == "foofoo"
- assert result["load_data.fn_str_inject.injected_data_1"] == (
+ assert result["fn_str_inject.load_data.injected_data_1"] == (
"foo",
{"loader": "string_data_loader"},
)
- assert result["load_data.fn_str_inject.injected_data_2"] == (
+ assert result["fn_str_inject.load_data.injected_data_2"] == (
2,
{"loader": "int_data_loader_2"},
)
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 0431b0893..96d908506 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -220,15 +220,6 @@ def test_add_dependency_user_nodes():
assert func_node.depended_on_by == []
-def test_create_function_graph_simple():
- """Tests that we create a simple function graph."""
- expected = create_testing_nodes()
- actual = graph.create_function_graph(
- tests.resources.dummy_functions, config={}, adapter=base.SimplePythonDataFrameGraphAdapter()
- )
- assert actual == expected
-
-
def create_testing_nodes():
"""Helper function for creating the nodes represented in dummy_functions.py."""
nodes = {
@@ -275,6 +266,15 @@ def create_testing_nodes():
return nodes
+def test_create_function_graph_simple():
+ """Tests that we create a simple function graph."""
+ expected = create_testing_nodes()
+ actual = graph.create_function_graph(
+ tests.resources.dummy_functions, config={}, adapter=base.SimplePythonDataFrameGraphAdapter()
+ )
+ assert actual == expected
+
+
def test_execute():
"""Tests graph execution along with basic memoization since A is depended on by two functions."""
adapter = base.SimplePythonDataFrameGraphAdapter()
@@ -800,3 +800,11 @@ def my_function(A: int, b: int, c: int) -> int:
)
results = fg.execute([n for n in fg.get_nodes() if n.name in ["my_function", "A"]])
assert results == {"A": 4, "b": 3, "c": 1, "my_function": 8}
+
+
+def test_update_dependencies():
+ nodes = create_testing_nodes()
+ new_nodes = graph.update_dependencies(nodes, base.DefaultAdapter())
+ for node_name, node_ in new_nodes.items():
+ assert node_.dependencies == nodes[node_name].dependencies
+ assert node_.depended_on_by == nodes[node_name].depended_on_by