-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdag.py
51 lines (42 loc) · 1.26 KB
/
dag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.bash import BashOperator
from datetime import datetime
from random import randint
def _choosing_best_model(ti):
    """Pick the branch to follow based on the best training accuracy.

    Pulls the XCom values returned by the three ``training_model_*`` tasks
    and returns the ``task_id`` of the branch the branching operator should
    follow.

    Args:
        ti: The task instance passed to the callable; only its
            ``xcom_pull`` method is used.

    Returns:
        ``'accurate'`` if the highest pulled accuracy is greater than 8,
        otherwise ``'inaccurate'``. This includes the case where no accuracy
        could be pulled at all.
    """
    # These ids must match the task_ids built in the DAG
    # (f"training_model_{model_id}"). With a mismatch, xcom_pull yields
    # None for every entry and max() would raise TypeError.
    pulled = ti.xcom_pull(task_ids=[
        'training_model_A',
        'training_model_B',
        'training_model_C',
    ])
    # Drop missing results so one absent XCom cannot break the comparison.
    accuracies = [acc for acc in (pulled or []) if acc is not None]
    if accuracies and max(accuracies) > 8:
        return 'accurate'
    return 'inaccurate'
def _training_model(model):
    """Simulate training ``model`` and report an accuracy score.

    No real training happens. ``model`` is accepted because the DAG passes
    it via ``op_kwargs``, but the score does not depend on it.

    Returns:
        A pseudo-random integer between 1 and 10, inclusive.
    """
    accuracy = randint(1, 10)
    return accuracy
# DAG "test": scheduled '@daily' from 2021-01-01 with catchup disabled.
with DAG(
    dag_id="test",
    start_date=datetime(2021, 1, 1),
    schedule_interval='@daily',
    catchup=False,
) as dag:
    # One training task per model id. Each calls _training_model with the
    # id passed as the ``model`` keyword argument.
    training_model_tasks = [
        PythonOperator(
            task_id=f"training_model_{model_id}",
            python_callable=_training_model,
            op_kwargs=dict(model=model_id),
        )
        for model_id in ['A', 'B', 'C']
    ]

    # Branching step. Its callable returns the task_id of the branch to run
    # ('accurate' or 'inaccurate').
    choosing_best_model = BranchPythonOperator(
        task_id="choosing_best_model",
        python_callable=_choosing_best_model,
    )

    # The two possible branch targets.
    accurate = BashOperator(
        task_id="accurate",
        bash_command="echo 'accurate'",
    )
    # NOTE(review): the leading space in this command is harmless to bash
    # but inconsistent with the 'accurate' task.
    inaccurate = BashOperator(
        task_id="inaccurate",
        bash_command=" echo 'inaccurate'",
    )

    # Every training task feeds the chooser, which fans out to both branches.
    training_model_tasks >> choosing_best_model
    choosing_best_model >> [accurate, inaccurate]