Skip to content

Commit

Permalink
simplified to only table-level dataflow
Browse files (browse the repository at this point in the history)
  • Loading branch information
zilto authored and skrawcz committed Mar 5, 2024
1 parent b982ff2 commit 26dc7c8
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 90 deletions.
61 changes: 0 additions & 61 deletions examples/ibisml/column_dataflow.py

This file was deleted.

Binary file modified examples/ibisml/cross_validation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed examples/ibisml/ibis_feature_set.png
Binary file not shown.
6 changes: 3 additions & 3 deletions examples/ibisml/model_training.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -79,11 +79,11 @@ def prepare_data(
train = transform(train_set)
df_train = train.to_pandas()
X_train = df_train[train.features]
y_train = df_train[train.outcomes]
y_train = df_train[train.outcomes].to_numpy().reshape(-1)

df_test = transform(val_set).to_pandas()
X_val = df_test[train.features]
y_val = df_test[train.outcomes]
y_val = df_test[train.outcomes].to_numpy().reshape(-1)

return dict(
X_train=X_train,
Expand Down Expand Up @@ -161,7 +161,7 @@ def train_full_model(
data = transform(feature_set)
df = data.to_pandas()
X = df[data.features]
y = df[data.outcomes]
y = df[data.outcomes].to_numpy().reshape(-1)

base_model.fit(X, y)
return dict(
Expand Down
34 changes: 8 additions & 26 deletions examples/ibisml/run.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
from hamilton import driver
from hamilton.plugins.h_tqdm import ProgressBar
from hamilton.execution.executors import SynchronousLocalTaskExecutor
from hamilton.plugins.h_tqdm import ProgressBar


def view_expression(expression, **kwargs):
Expand All @@ -16,33 +16,17 @@ def view_expression(expression, **kwargs):
return dot


def main(level: str, model: str):
dataflow_components = []
config = {}
final_vars = ["feature_set"]

if level == "column":
import column_dataflow

dataflow_components.append(column_dataflow)
elif level == "table":
import table_dataflow

dataflow_components.append(table_dataflow)
else:
raise ValueError("`level` must be in ['column', 'table']")

if model:
import model_training
def main(model: str):
import model_training
import table_dataflow

dataflow_components.append(model_training)
config["model"] = model
final_vars.extend(["full_model", "fitted_recipe", "cross_validation_scores"])
config = {"model": model}
final_vars = ["full_model", "fitted_recipe", "cross_validation_scores"]

# build the Driver from modules
dr = (
driver.Builder()
.with_modules(*dataflow_components)
.with_modules(table_dataflow, model_training)
.with_config(config)
.with_adapters(ProgressBar())
.enable_dynamic_execution(allow_experimental_mode=True)
Expand All @@ -68,7 +52,6 @@ def main(level: str, model: str):
)

res = dr.execute(final_vars, inputs=inputs)
view_expression(res["feature_set"], filename="ibis_feature_set", format="png")

print("Dataflow result keys: ", list(res.keys()))

Expand All @@ -77,8 +60,7 @@ def main(level: str, model: str):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--level", choices=["column", "table"])
parser.add_argument("--model", choices=["linear", "random_forest", "boosting"])
args = parser.parse_args()

main(level=args.level, model=args.model)
main(model=args.model)

0 comments on commit 26dc7c8

Please sign in to comment.