Skip to content

Commit

Permalink
refactor: all run and schedule args can optionally configured via con…
Browse files Browse the repository at this point in the history
…f files
  • Loading branch information
legout committed Aug 21, 2024
1 parent 8602ffb commit 9514a5f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 84 deletions.
Binary file added .DS_Store
Binary file not shown.
23 changes: 7 additions & 16 deletions src/flowerpower/helpers/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ def _get_cron_trigger(
if crontab is not None:
return (CronTrigger.from_crontab(crontab), kwargs)
else:
return (
CronTrigger(
return CronTrigger(
year=kwargs.pop("year", None),
month=kwargs.pop("month", None),
week=kwargs.pop("week", None),
Expand All @@ -103,9 +102,7 @@ def _get_cron_trigger(
start_time=start_time,
end_time=end_time,
timezone=timezone,
),
kwargs,
)
)

def _get_interval_trigger(
self,
Expand All @@ -115,8 +112,7 @@ def _get_interval_trigger(
):
from apscheduler.triggers.interval import IntervalTrigger

return (
IntervalTrigger(
return IntervalTrigger(
weeks=kwargs.pop("weeks", 0),
days=kwargs.pop("days", 0),
hours=kwargs.pop("hours", 0),
Expand All @@ -125,9 +121,7 @@ def _get_interval_trigger(
microseconds=kwargs.pop("microseconds", 0),
start_time=start_time,
end_time=end_time,
),
kwargs,
)
)

def _get_calendar_trigger(
self,
Expand All @@ -138,8 +132,7 @@ def _get_calendar_trigger(
):
from apscheduler.triggers.calendarinterval import CalendarIntervalTrigger

return (
CalendarIntervalTrigger(
return CalendarIntervalTrigger(
weeks=kwargs.pop("weeks", 0),
days=kwargs.pop("days", 0),
hours=kwargs.pop("hours", 0),
Expand All @@ -148,14 +141,12 @@ def _get_calendar_trigger(
start_time=start_time,
end_time=end_time,
timezone=timezone,
),
kwargs,
)
)

def _get_date_trigger(self, start_time: dt.datetime, **kwargs):
from apscheduler.triggers.date import DateTrigger

return (DateTrigger(run_time=start_time), kwargs)
return DateTrigger(run_time=start_time)


def get_trigger(trigger_type: str, **kwargs):
Expand Down
124 changes: 56 additions & 68 deletions src/flowerpower/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
import os
import sys
from tkinter import ALL
from typing import Any, Callable
from uuid import UUID

Expand All @@ -21,7 +22,7 @@


from .helpers.executor import get_executor
from .helpers.trigger import get_trigger
from .helpers.trigger import get_trigger, ALL_TRIGGER_KWARGS


class PipelineManager:
Expand Down Expand Up @@ -136,36 +137,26 @@ def _get_driver(
self._load_module(name)

if with_tracker:
project_id = kwargs.pop("project_id", None) or self.cfg.tracker.pipeline[
name
].get("project_id", None)
username = kwargs.pop("username", None) or self.cfg.tracker.get(
"username", None
)
dag_name = kwargs.pop("dag_name", None) or self.cfg.tracker.pipeline[
name
].get("dag_name", None)
tags = kwargs.pop("tags", None) or self.cfg.tracker.pipeline[name].get(
"tags", None
)
api_url = kwargs.pop("api_url", None) or self.cfg.tracker.get(
"api_url", None
)
ui_url = kwargs.pop("ui_url", None) or self.cfg.tracker.get("ui_url", None)
tracker_cfg = self.cfg.tracker.pipeline.get(name, {})
tracker_kwargs = {
key: kwargs.pop(key, None) or tracker_cfg.get(key, None)
for key in [
"project_id",
"username",
"dag_name",
"tags",
"api_url",
"ui_url",
]
}
project_id = tracker_kwargs.get("project_id", None)

if project_id is None:
raise ValueError(
"Please provide a project_id if you want to use the tracker"
)

tracker = adapters.HamiltonTracker(
project_id=project_id,
username=username,
dag_name=dag_name,
tags=tags,
hamilton_api_url=api_url,
hamilton_ui_url=ui_url,
)
tracker = adapters.HamiltonTracker(project_id=project_id, **tracker_kwargs)

dr = (
driver.Builder()
Expand All @@ -190,9 +181,9 @@ def _run(
self,
name: str,
environment: str = "dev",
executor: str | None = None,
inputs: dict | None = None,
final_vars: list | None = None,
executor: str | None = None,
with_tracker: bool | None = None,
reload: bool = False,
**kwargs,
Expand All @@ -215,18 +206,23 @@ def _run(
"""
logger.info(f"Starting pipeline {name} in environment {environment}")

pipeline_cfg = self.cfg.pipeline
run_params = pipeline_cfg.run.get(name)[environment]
run_params = self.cfg.pipeline.run.get(name)[environment]

final_vars = final_vars or run_params.get("final_vars", [])
inputs = {**(run_params.get("inputs", {}) or {}), **(inputs or {})}
with_tracker = with_tracker or run_params.get("with_tracker", False)
inputs = {
**(run_params.get("inputs", {}) or {}),
**(inputs or {}),
} # <-- inputs override and adds to run_params

kwargs.update(
{
arg: eval(arg) or run_params.get(arg, None)
for arg in ["executor", "with_tracker", "reload"]
}
)

dr, shutdown = self._get_driver(
name=name,
executor=executor,
with_tracker=with_tracker,
reload=reload,
**kwargs,
)

Expand Down Expand Up @@ -362,8 +358,6 @@ def add_job(
with SchedulerManager(
name=name, base_dir=self._base_dir, role="scheduler"
) as sm:
# if not any([task.id == "run-pipeline" for task in sm.get_tasks()]):
# sm.configure_task(func_or_task_id="run-pipeline", func=self._run)
return sm.add_job(
self._run,
args=(
Expand All @@ -385,11 +379,11 @@ def add_job(
def schedule(
self,
name: str,
inputs: dict | None = None,
final_vars: list | None = None,
environment: str = "dev",
executor: str | None = None,
trigger_type: str | None = None,
inputs: dict | None = None,
final_vars: list | None = None,
with_tracker: bool | None = None,
paused: bool = False,
coalesce: str = "latest",
Expand Down Expand Up @@ -431,49 +425,43 @@ def schedule(
if SchedulerManager is None:
raise ValueError("APScheduler4 not installed. Please install it first.")

trigger_kwargs = {}
if "pipeline" in self.cfg.scheduler:
scheduler_cfg = self.cfg.scheduler.pipeline.get(name, None).copy()
scheduler_cfg = self.cfg.scheduler.pipeline.get(name, None) # .copy()
else:
scheduler_cfg = None

if scheduler_cfg is not None:
trigger_type = trigger_type or scheduler_cfg.pop("trigger_type", None)
for key in [
"crontab",
"year",
"month",
"week",
"day",
"days_of_week",
"hour",
"minute",
"second",
"timezone",
]:
trigger_kwargs[key] = scheduler_cfg.pop(key, None)
scheduler_cfg = {}

trigger_type = trigger_type or scheduler_cfg.get("trigger_type", None)

trigger_kwargs = {
key: kwargs.pop(key, None) or scheduler_cfg.get(key, None)
for key in ALL_TRIGGER_KWARGS.get(trigger_type, [])
if key in kwargs or key in scheduler_cfg
}

schedule_kwargs = {
arg: eval(arg) or scheduler_cfg.get(arg, None)
for arg in [
"executor",
"paused",
"coalesce",
"misfire_grace_time",
"max_jitter",
"max_running_jobs",
"conflict_policy",
]
}

with SchedulerManager(
name=name, base_dir=self._base_dir, role="scheduler"
) as sm:
# if not any([task.id == "run-pipeline" for task in sm.get_tasks()]):
# sm.configure_task(func_or_task_id="run-pipeline", func=self._run)
trigger, kwargs = get_trigger(trigger_type, **kwargs)
trigger = get_trigger(trigger_type, **trigger_kwargs)

id_ = sm.add_schedule(
self._run,
trigger=trigger,
args=(name, environment, executor, inputs, final_vars, with_tracker),
kwargs=kwargs,
job_executor=executor
if executor in ["async", "threadpool", "processpool"]
else "async",
paused=paused,
coalesce=coalesce,
misfire_grace_time=misfire_grace_time,
max_jitter=max_jitter,
max_running_jobs=max_running_jobs,
conflict_policy=conflict_policy,
**schedule_kwargs,
)
logger.success(
f"Added scheduler for {name} in environment {environment} with id {id_}"
Expand Down

0 comments on commit 9514a5f

Please sign in to comment.