diff --git a/airflow/__init__.py b/airflow/__init__.py
index db3fcd611c740..1ed188cc45886 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -31,7 +31,7 @@
from airflow.models import DAG
from flask_admin import BaseView
from importlib import import_module
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
if DAGS_FOLDER not in sys.path:
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index d9a8d0667e44a..1531c18aaa5dc 100755
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -14,11 +14,14 @@
import json
import airflow
-from airflow import jobs, settings, utils
+from airflow import jobs, settings
from airflow import configuration as conf
from airflow.executors import DEFAULT_EXECUTOR
from airflow.models import DagModel, DagBag, TaskInstance, DagPickle, DagRun
-from airflow.utils import AirflowException, State
+from airflow.utils import db as db_utils
+from airflow.utils import logging as logging_utils
+from airflow.utils.state import State
+from airflow.exceptions import AirflowException
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
@@ -78,7 +81,8 @@ def backfill(args, dag=None):
mark_success=args.mark_success,
include_adhoc=args.include_adhoc,
local=args.local,
- donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
+ donot_pickle=(args.donot_pickle or
+ conf.getboolean('core', 'donot_pickle')),
ignore_dependencies=args.ignore_dependencies,
pool=args.pool)
@@ -133,7 +137,7 @@ def set_is_paused(is_paused, args, dag=None):
def run(args, dag=None):
- utils.pessimistic_connection_handling()
+ db_utils.pessimistic_connection_handling()
if dag:
args.dag_id = dag.dag_id
@@ -236,10 +240,10 @@ def run(args, dag=None):
remote_log_location = filename.replace(log_base, remote_base)
# S3
if remote_base.startswith('s3:/'):
- utils.S3Log().write(log, remote_log_location)
+ logging_utils.S3Log().write(log, remote_log_location)
# GCS
elif remote_base.startswith('gs:/'):
- utils.GCSLog().write(
+ logging_utils.GCSLog().write(
log,
remote_log_location,
append=True)
@@ -401,7 +405,7 @@ def worker(args):
def initdb(args): # noqa
print("DB: " + repr(settings.engine.url))
- utils.initdb()
+ db_utils.initdb()
print("Done.")
@@ -412,14 +416,14 @@ def resetdb(args):
"Proceed? (y/n)").upper() == "Y":
logging.basicConfig(level=settings.LOGGING_LEVEL,
format=settings.SIMPLE_LOG_FORMAT)
- utils.resetdb()
+ db_utils.resetdb()
else:
print("Bail.")
def upgradedb(args): # noqa
print("DB: " + repr(settings.engine.url))
- utils.upgradedb()
+ db_utils.upgradedb()
def version(args): # noqa
diff --git a/airflow/configuration.py b/airflow/configuration.py
index 54c3e5ca0ef53..04b555250aad9 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -465,6 +465,7 @@ def read(self, filenames):
ConfigParser.read(self, filenames)
self._validate()
+
def mkdir_p(path):
try:
os.makedirs(path)
@@ -534,6 +535,7 @@ def test_mode():
def get(section, key, **kwargs):
return conf.get(section, key, **kwargs)
+
def getboolean(section, key):
return conf.getboolean(section, key)
@@ -549,14 +551,17 @@ def getint(section, key):
def has_option(section, key):
return conf.has_option(section, key)
+
def remove_option(section, option):
return conf.remove_option(section, option)
+
def set(section, option, value): # noqa
return conf.set(section, option, value)
########################
# convenience method to access config entries
+
def get_dags_folder():
return os.path.expanduser(get('core', 'DAGS_FOLDER'))
diff --git a/airflow/contrib/executors/mesos_executor.py b/airflow/contrib/executors/mesos_executor.py
index 3b82306f2bd5a..45a474dc3a1a3 100644
--- a/airflow/contrib/executors/mesos_executor.py
+++ b/airflow/contrib/executors/mesos_executor.py
@@ -11,8 +11,8 @@
from airflow import configuration
from airflow.executors.base_executor import BaseExecutor
from airflow.settings import Session
-from airflow.utils import State
-from airflow.utils import AirflowException
+from airflow.utils.state import State
+from airflow.exceptions import AirflowException
DEFAULT_FRAMEWORK_NAME = 'Airflow'
diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py
index 1507b09fef276..46c6a806dd099 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -1,6 +1,6 @@
# Imports the hooks dynamically while keeping the package API clean,
# abstracting the underlying modules
-from airflow.utils import import_module_attrs as _import_module_attrs
+from airflow.utils.helpers import import_module_attrs as _import_module_attrs
_hooks = {
'ftp_hook': ['FTPHook'],
diff --git a/airflow/contrib/hooks/gc_base_hook.py b/airflow/contrib/hooks/gc_base_hook.py
index b17d37fed5fe0..6af01e79eeffb 100644
--- a/airflow/contrib/hooks/gc_base_hook.py
+++ b/airflow/contrib/hooks/gc_base_hook.py
@@ -2,7 +2,7 @@
import logging
from airflow.hooks.base_hook import BaseHook
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from oauth2client.client import SignedJwtAssertionCredentials, GoogleCredentials
class GoogleCloudBaseHook(BaseHook):
diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py
index 7e2fdb32dcbcc..c36b9f5838f5e 100755
--- a/airflow/contrib/hooks/qubole_hook.py
+++ b/airflow/contrib/hooks/qubole_hook.py
@@ -3,7 +3,7 @@
import datetime
import logging
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from airflow import configuration
@@ -151,4 +151,4 @@ def create_cmd_args(self):
else:
args += inplace_args.split(' ')
- return args
\ No newline at end of file
+ return args
diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py
index bc590fe2789f8..c5aff848bd865 100644
--- a/airflow/contrib/hooks/ssh_hook.py
+++ b/airflow/contrib/hooks/ssh_hook.py
@@ -20,7 +20,7 @@
from contextlib import contextmanager
from airflow.hooks.base_hook import BaseHook
-from airflow import AirflowException
+from airflow.exceptions import AirflowException
import logging
diff --git a/airflow/contrib/operators/__init__.py b/airflow/contrib/operators/__init__.py
index f178e392f2870..3598490a8eeda 100644
--- a/airflow/contrib/operators/__init__.py
+++ b/airflow/contrib/operators/__init__.py
@@ -1,6 +1,6 @@
# Imports the operators dynamically while keeping the package API clean,
# abstracting the underlying modules
-from airflow.utils import import_module_attrs as _import_module_attrs
+from airflow.utils.helpers import import_module_attrs as _import_module_attrs
_operators = {
'ssh_execute_operator': ['SSHExecuteOperator'],
diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py
index 69b69c37f36da..218de5a10ec45 100644
--- a/airflow/contrib/operators/bigquery_check_operator.py
+++ b/airflow/contrib/operators/bigquery_check_operator.py
@@ -1,6 +1,6 @@
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
from airflow.operators import CheckOperator, ValueCheckOperator, IntervalCheckOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class BigQueryCheckOperator(CheckOperator):
diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py
index 78edde2ca963c..2f60ac60ab06d 100644
--- a/airflow/contrib/operators/bigquery_operator.py
+++ b/airflow/contrib/operators/bigquery_operator.py
@@ -2,7 +2,7 @@
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class BigQueryOperator(BaseOperator):
"""
diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py
index ccb9b0715f894..56023f584041b 100644
--- a/airflow/contrib/operators/bigquery_to_bigquery.py
+++ b/airflow/contrib/operators/bigquery_to_bigquery.py
@@ -2,7 +2,7 @@
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class BigQueryToBigQueryOperator(BaseOperator):
"""
diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py
index 012b3bb4a6ffb..3a543fdb3052c 100644
--- a/airflow/contrib/operators/bigquery_to_gcs.py
+++ b/airflow/contrib/operators/bigquery_to_gcs.py
@@ -2,7 +2,7 @@
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class BigQueryToCloudStorageOperator(BaseOperator):
"""
diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py
index ef917e85de219..8de6d1728af55 100644
--- a/airflow/contrib/operators/gcs_download_operator.py
+++ b/airflow/contrib/operators/gcs_download_operator.py
@@ -2,7 +2,7 @@
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class GoogleCloudStorageDownloadOperator(BaseOperator):
"""
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index bcb418c636ac7..ec3a459238bb1 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -4,7 +4,7 @@
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class GoogleCloudStorageToBigQueryOperator(BaseOperator):
"""
diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py
index cee634ee12871..0eb368e3fa95e 100644
--- a/airflow/contrib/operators/mysql_to_gcs.py
+++ b/airflow/contrib/operators/mysql_to_gcs.py
@@ -5,7 +5,7 @@
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.hooks import MySqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
from collections import OrderedDict
from datetime import date, datetime
from decimal import Decimal
diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py
index 9c05479827b6b..2ed94e1023402 100755
--- a/airflow/contrib/operators/qubole_operator.py
+++ b/airflow/contrib/operators/qubole_operator.py
@@ -1,5 +1,5 @@
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
from airflow.contrib.hooks import QuboleHook
diff --git a/airflow/contrib/operators/ssh_execute_operator.py b/airflow/contrib/operators/ssh_execute_operator.py
index 0c20719660cba..c55f0d177c3b2 100644
--- a/airflow/contrib/operators/ssh_execute_operator.py
+++ b/airflow/contrib/operators/ssh_execute_operator.py
@@ -4,8 +4,8 @@
from subprocess import STDOUT
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
-from airflow.utils import AirflowException
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
class SSHTempFileContent():
diff --git a/airflow/contrib/operators/vertica_operator.py b/airflow/contrib/operators/vertica_operator.py
index 08003114d5fde..9e5248f03fcb3 100644
--- a/airflow/contrib/operators/vertica_operator.py
+++ b/airflow/contrib/operators/vertica_operator.py
@@ -2,7 +2,7 @@
from airflow.contrib.hooks import VerticaHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class VerticaOperator(BaseOperator):
diff --git a/airflow/contrib/operators/vertica_to_hive.py b/airflow/contrib/operators/vertica_to_hive.py
index 17a59680b76d8..35a489a9beba3 100644
--- a/airflow/contrib/operators/vertica_to_hive.py
+++ b/airflow/contrib/operators/vertica_to_hive.py
@@ -7,7 +7,7 @@
from airflow.hooks import HiveCliHook
from airflow.contrib.hooks import VerticaHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class VerticaToHiveTransfer(BaseOperator):
"""
diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_operator.py
index 3a4c6315eae9b..967c65edc3f5c 100644
--- a/airflow/example_dags/example_short_circuit_operator.py
+++ b/airflow/example_dags/example_short_circuit_operator.py
@@ -1,6 +1,6 @@
from airflow.operators import ShortCircuitOperator, DummyOperator
from airflow.models import DAG
-import airflow.utils
+import airflow.utils.helpers
from datetime import datetime, timedelta
seven_days_ago = datetime.combine(datetime.today() - timedelta(7),
@@ -21,5 +21,5 @@
ds_true = [DummyOperator(task_id='true_' + str(i), dag=dag) for i in [1, 2]]
ds_false = [DummyOperator(task_id='false_' + str(i), dag=dag) for i in [1, 2]]
-airflow.utils.chain(cond_true, *ds_true)
-airflow.utils.chain(cond_false, *ds_false)
+airflow.utils.helpers.chain(cond_true, *ds_true)
+airflow.utils.helpers.chain(cond_false, *ds_false)
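For readers who only know the old flat airflow.utils module: chain builds a linear dependency by calling set_downstream on each consecutive pair of tasks, so chain(cond_true, *ds_true) above wires cond_true -> true_1 -> true_2. A minimal, self-contained sketch under that assumption (DAG id and task ids are illustrative):

    from datetime import datetime

    from airflow.models import DAG
    from airflow.operators import DummyOperator
    from airflow.utils.helpers import chain

    dag = DAG('chain_sketch', start_date=datetime(2016, 1, 1))
    t1, t2, t3 = [DummyOperator(task_id='t%d' % i, dag=dag) for i in (1, 2, 3)]

    # chain(t1, t2, t3) is equivalent to:
    #   t1.set_downstream(t2)
    #   t2.set_downstream(t3)
    chain(t1, t2, t3)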
diff --git a/airflow/example_dags/example_trigger_controller_dag.py b/airflow/example_dags/example_trigger_controller_dag.py
index 3b463a4d23c87..657672e305198 100644
--- a/airflow/example_dags/example_trigger_controller_dag.py
+++ b/airflow/example_dags/example_trigger_controller_dag.py
@@ -1,3 +1,4 @@
+
"""This example illustrates the use of the TriggerDagRunOperator. There are 2
entities at work in this scenario:
1. The Controller DAG - the DAG that conditionally executes the trigger
@@ -14,6 +15,7 @@
state is then made available to the TargetDag
2. A Target DAG : c.f. example_trigger_target_dag.py
"""
+
from airflow import DAG
from airflow.operators import TriggerDagRunOperator
from datetime import datetime
@@ -35,8 +37,8 @@ def conditionally_trigger(context, dag_run_obj):
# Define the DAG
dag = DAG(dag_id='example_trigger_controller_dag',
- default_args={"owner" : "me",
- "start_date":datetime.now()},
+ default_args={"owner": "me",
+ "start_date": datetime.now()},
schedule_interval='@once')
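The conditionally_trigger callable named in the hunk header above follows the TriggerDagRunOperator contract: it receives the task context and a dag_run_obj, and returns the dag_run_obj (optionally carrying a payload, which the target DAG sees as dag_run.conf) to fire a run, or returns nothing to skip it. A rough sketch of that shape, with 'condition_param' as an illustrative parameter name ('message' matches the key read by the target DAG below):

    def conditionally_trigger(context, dag_run_obj):
        # Fire the target DAG only when the illustrative 'condition_param'
        # is truthy; the payload surfaces in the target as dag_run.conf.
        if context['params'].get('condition_param'):
            dag_run_obj.payload = {'message': context['params'].get('message')}
            return dag_run_obj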
diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py
index 9d548813b10f0..172003f05fc2a 100644
--- a/airflow/example_dags/example_trigger_target_dag.py
+++ b/airflow/example_dags/example_trigger_target_dag.py
@@ -34,7 +34,7 @@
def run_this_func(ds, **kwargs):
- print( "Remotely received value of {} for key=message".format(kwargs['dag_run'].conf['message']))
+ print("Remotely received value of {} for key=message".format(kwargs['dag_run'].conf['message']))
run_this = PythonOperator(
task_id='run_this',
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
new file mode 100644
index 0000000000000..2468643ee47b7
--- /dev/null
+++ b/airflow/exceptions.py
@@ -0,0 +1,10 @@
+class AirflowException(Exception):
+ pass
+
+
+class AirflowSensorTimeout(Exception):
+ pass
+
+
+class AirflowTaskTimeout(Exception):
+ pass
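Everything that previously did from airflow.utils import AirflowException now imports from this new module, as the rest of this patch shows. A minimal usage sketch of the new import path (the function and argument names here are purely illustrative):

    import logging

    from airflow.exceptions import AirflowException, AirflowTaskTimeout

    def run_safely(callable_task):
        # 'run_safely' and 'callable_task' are hypothetical names.
        try:
            return callable_task()
        except AirflowTaskTimeout:
            raise  # let timeouts keep propagating
        except AirflowException as err:
            logging.error("Task failed: %s", err)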
diff --git a/airflow/executors/__init__.py b/airflow/executors/__init__.py
index 695ef5ce7f24b..31635a1374725 100644
--- a/airflow/executors/__init__.py
+++ b/airflow/executors/__init__.py
@@ -10,7 +10,7 @@
except:
pass
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
_EXECUTOR = configuration.get('core', 'EXECUTOR')
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index db68860d66ab1..a0c26ebcbb237 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -1,7 +1,8 @@
from builtins import range
from airflow import configuration
-from airflow.utils import State, LoggingMixin
+from airflow.utils.state import State
+from airflow.utils.logging import LoggingMixin
PARALLELISM = configuration.getint('core', 'PARALLELISM')
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 5b3cd9da98818..088cb0b488ed6 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -6,7 +6,7 @@
from celery import Celery
from celery import states as celery_states
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow import configuration
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index 19ada6a799946..15a89e169cd64 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -6,7 +6,8 @@
from airflow import configuration
from airflow.executors.base_executor import BaseExecutor
-from airflow.utils import State, LoggingMixin
+from airflow.utils.state import State
+from airflow.utils.logging import LoggingMixin
PARALLELISM = configuration.get('core', 'PARALLELISM')
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 4684226a1aa5f..53d9f0a626ea5 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -1,9 +1,8 @@
from builtins import str
-import logging
import subprocess
from airflow.executors.base_executor import BaseExecutor
-from airflow.utils import State
+from airflow.utils.state import State
class SequentialExecutor(BaseExecutor):
diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py
index 00b6a0cdcacdd..40ac1fb98f8ce 100644
--- a/airflow/hooks/S3_hook.py
+++ b/airflow/hooks/S3_hook.py
@@ -15,7 +15,7 @@
boto.set_stream_logger('boto')
logging.getLogger("boto").setLevel(logging.INFO)
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py
index d4c4c27141e2f..58fac177e84d8 100644
--- a/airflow/hooks/__init__.py
+++ b/airflow/hooks/__init__.py
@@ -1,6 +1,7 @@
# Imports the hooks dynamically while keeping the package API clean,
# abstracting the underlying modules
-from airflow.utils import import_module_attrs as _import_module_attrs
+
+from airflow.utils.helpers import import_module_attrs as _import_module_attrs
from airflow.hooks.base_hook import BaseHook # noqa to expose in package
_hooks = {
diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py
index 15fadad83e5b1..2a6cb734d7bda 100644
--- a/airflow/hooks/base_hook.py
+++ b/airflow/hooks/base_hook.py
@@ -10,7 +10,7 @@
from airflow import settings
from airflow.models import Connection
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
CONN_ENV_PREFIX = 'AIRFLOW_CONN_'
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index 1c98deaf1cb19..10c5acd0b94df 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -6,7 +6,7 @@
import logging
from airflow.hooks.base_hook import BaseHook
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
class DbApiHook(BaseHook):
diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py
index 3c216d9d798a4..f01b1e39ffc19 100644
--- a/airflow/hooks/druid_hook.py
+++ b/airflow/hooks/druid_hook.py
@@ -7,7 +7,7 @@
import requests
from airflow.hooks.base_hook import BaseHook
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
LOAD_CHECK_INTERVAL = 5
diff --git a/airflow/hooks/hdfs_hook.py b/airflow/hooks/hdfs_hook.py
index f02cc7cffc8d8..3885bbd05c45a 100644
--- a/airflow/hooks/hdfs_hook.py
+++ b/airflow/hooks/hdfs_hook.py
@@ -7,7 +7,7 @@
except ImportError:
snakebite_imported = False
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
class HDFSHookException(AirflowException):
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index 6e3318abce82d..15d1c98b33057 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -22,9 +22,9 @@
import subprocess
from tempfile import NamedTemporaryFile
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
-from airflow.utils import TemporaryDirectory
+from airflow.utils.file import TemporaryDirectory
from airflow import configuration
import airflow.security.utils as utils
@@ -411,7 +411,7 @@ def table_exists(self, table_name, db='default'):
class HiveServer2Hook(BaseHook):
"""
- Wrapper around the impala library
+ Wrapper around the impyla library
Note that the default authMechanism is PLAIN, to override it you
can specify it in the ``extra`` of your connection in the UI as in
diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py
index 4d0eb71790ddf..07cf9f264931a 100644
--- a/airflow/hooks/http_hook.py
+++ b/airflow/hooks/http_hook.py
@@ -4,7 +4,7 @@
import requests
from airflow.hooks.base_hook import BaseHook
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
class HttpHook(BaseHook):
diff --git a/airflow/hooks/pig_hook.py b/airflow/hooks/pig_hook.py
index 4d63fee7833a6..5b40e52536c39 100644
--- a/airflow/hooks/pig_hook.py
+++ b/airflow/hooks/pig_hook.py
@@ -3,9 +3,9 @@
import subprocess
from tempfile import NamedTemporaryFile
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
-from airflow.utils import TemporaryDirectory
+from airflow.utils.file import TemporaryDirectory
from airflow import configuration
diff --git a/airflow/hooks/webhdfs_hook.py b/airflow/hooks/webhdfs_hook.py
index 83e6eaa54b39c..79a23bc38cee2 100644
--- a/airflow/hooks/webhdfs_hook.py
+++ b/airflow/hooks/webhdfs_hook.py
@@ -11,7 +11,7 @@
except ImportError:
logging.error("Could not load the Kerberos extension for the WebHDFSHook.")
raise
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
class AirflowWebHDFSHookException(AirflowException):
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 5a939ed072d41..9ab47ba02c3c7 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -33,9 +33,14 @@
from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_
from sqlalchemy.orm.session import make_transient
-from airflow import executors, models, settings, utils
+from airflow import executors, models, settings
from airflow import configuration as conf
-from airflow.utils import AirflowException, State, LoggingMixin
+from airflow.exceptions import AirflowException
+from airflow.utils.state import State
+from airflow.utils.db import provide_session, pessimistic_connection_handling
+from airflow.utils.email import send_email
+from airflow.utils.logging import LoggingMixin
+from airflow.utils import asciiart
Base = models.Base
@@ -233,7 +238,7 @@ def __init__(
self.heartrate = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC')
- @utils.provide_session
+ @provide_session
def manage_slas(self, dag, session=None):
"""
Finding all tasks that have SLAs defined, and sending alert emails
@@ -322,12 +327,11 @@ def manage_slas(self, dag, session=None):
self.logger.info(' --------------> ABOUT TO CALL SLA MISS CALL BACK ')
dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
notification_sent = True
- from airflow import ascii
email_content = """\
Here's a list of tasks thas missed their SLAs:
{task_list}\n
Blocking tasks:
- {blocking_task_list}\n{ascii.bug}
+ {blocking_task_list}\n{asciiart.bug}
""".format(**locals())
emails = []
for t in dag.tasks:
@@ -340,7 +344,7 @@ def manage_slas(self, dag, session=None):
if email not in emails:
emails.append(email)
if emails and len(slas):
- utils.send_email(
+ send_email(
emails,
"[airflow] SLA miss on DAG=" + dag.dag_id,
email_content)
@@ -516,7 +520,7 @@ def process_dag(self, dag, executor):
session.close()
- @utils.provide_session
+ @provide_session
def prioritize_queued(self, session, executor, dagbag):
# Prioritizing queued task instances
@@ -608,7 +612,7 @@ def signal_handler(signum, frame):
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler)
- utils.pessimistic_connection_handling()
+ pessimistic_connection_handling()
logging.basicConfig(level=logging.DEBUG)
self.logger.info("Starting the scheduler")
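As used in the hunks above, provide_session (now living in airflow.utils.db) injects a SQLAlchemy session into the decorated callable when the caller does not pass one, and expunges, commits and closes that session afterwards. A small sketch under that assumption; count_task_instances and dag_id are illustrative names, not part of this patch:

    from airflow.models import TaskInstance
    from airflow.utils.db import provide_session

    @provide_session
    def count_task_instances(dag_id, session=None):
        # 'session' is created, committed and closed by the decorator
        # whenever the caller omits it.
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_id).count()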
diff --git a/airflow/models.py b/airflow/models.py
index 69c6d930fc78c..0a4ceb45970eb 100644
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -53,9 +53,17 @@
from airflow import settings, utils
from airflow.executors import DEFAULT_EXECUTOR, LocalExecutor
from airflow import configuration
-from airflow.utils import (
- AirflowException, State, apply_defaults, provide_session,
- is_container, as_tuple, TriggerRule, LoggingMixin)
+from airflow.exceptions import AirflowException
+from airflow.utils.dates import cron_presets, date_range as utils_date_range
+from airflow.utils.db import provide_session
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.email import send_email
+from airflow.utils.helpers import (as_tuple, is_container, is_in, validate_key)
+from airflow.utils.logging import LoggingMixin
+from airflow.utils.state import State
+from airflow.utils.timeout import timeout
+from airflow.utils.trigger_rule import TriggerRule
+
Base = declarative_base()
ID_LEN = 250
@@ -211,13 +219,13 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
return found_dags
if (not only_if_updated or
- filepath not in self.file_last_changed or
- dttm != self.file_last_changed[filepath]):
+ filepath not in self.file_last_changed or
+ dttm != self.file_last_changed[filepath]):
try:
self.logger.info("Importing " + filepath)
if mod_name in sys.modules:
del sys.modules[mod_name]
- with utils.timeout(
+ with timeout(
configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
m = imp.load_source(mod_name, filepath)
except Exception as e:
@@ -1067,7 +1075,7 @@ def signal_handler(signum, frame):
# if it goes beyond
result = None
if task_copy.execution_timeout:
- with utils.timeout(int(
+ with timeout(int(
task_copy.execution_timeout.total_seconds())):
result = task_copy.execute(context=context)
@@ -1245,7 +1253,7 @@ def email_alert(self, exception, is_retry=False):
"Log file: {self.log_filepath}
"
"Mark success: Link
"
).format(**locals())
- utils.send_email(task.email, title, body)
+ send_email(task.email, title, body)
def set_duration(self):
if self.end_date and self.start_date:
@@ -1531,7 +1539,7 @@ def __init__(
*args,
**kwargs):
- utils.validate_key(task_id)
+ validate_key(task_id)
self.dag_id = dag.dag_id if dag else 'adhoc_' + owner
self.task_id = task_id
self.owner = owner
@@ -1836,7 +1844,7 @@ def get_flat_relatives(self, upstream=False, l=None):
if not l:
l = []
for t in self.get_direct_relatives(upstream):
- if not utils.is_in(t, l):
+ if not is_in(t, l):
l.append(t)
t.get_flat_relatives(upstream, l)
return l
@@ -2100,14 +2108,14 @@ def __init__(
self.params.update(self.default_args['params'])
del self.default_args['params']
- utils.validate_key(dag_id)
+ validate_key(dag_id)
self.tasks = []
self.dag_id = dag_id
self.start_date = start_date
self.end_date = end_date
self.schedule_interval = schedule_interval
- if schedule_interval in utils.cron_presets:
- self._schedule_interval = utils.cron_presets.get(schedule_interval)
+ if schedule_interval in cron_presets:
+ self._schedule_interval = cron_presets.get(schedule_interval)
elif schedule_interval == '@once':
self._schedule_interval = None
else:
@@ -2164,7 +2172,7 @@ def __hash__(self):
def date_range(self, start_date, num=None, end_date=datetime.now()):
if num:
end_date = None
- return utils.date_range(
+ return utils_date_range(
start_date=start_date, end_date=end_date,
num=num, delta=self._schedule_interval)
@@ -2379,7 +2387,7 @@ def roots(self):
@provide_session
def set_dag_runs_state(
self, start_date, end_date, state=State.RUNNING, session=None):
- dates = utils.date_range(start_date, end_date)
+ dates = utils_date_range(start_date, end_date)
drs = session.query(DagModel).filter_by(dag_id=self.dag_id).all()
for dr in drs:
dr.state = State.RUNNING
@@ -2436,7 +2444,7 @@ def clear(
"You are about to delete these {count} tasks:\n"
"{ti_list}\n\n"
"Are you sure? (yes/no): ").format(**locals())
- do_it = utils.ask_yesno(question)
+ do_it = utils.helpers.ask_yesno(question)
if do_it:
clear_task_instances(tis, session)
@@ -2918,6 +2926,7 @@ def __repr__(self):
def id_for_date(klass, date, prefix=ID_FORMAT_PREFIX):
return prefix.format(date.isoformat()[:19])
+
class Pool(Base):
__tablename__ = "slot_pool"
diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py
index 026ee3b13e643..daccbd165a066 100644
--- a/airflow/operators/__init__.py
+++ b/airflow/operators/__init__.py
@@ -1,6 +1,6 @@
# Imports operators dynamically while keeping the package API clean,
# abstracting the underlying modules
-from airflow.utils import import_module_attrs as _import_module_attrs
+from airflow.utils.helpers import import_module_attrs as _import_module_attrs
# These need to be integrated first as other operators depend on them
_import_module_attrs(globals(), {
diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py
index e6925ec3f8cc7..cae53b36cb5d3 100644
--- a/airflow/operators/bash_operator.py
+++ b/airflow/operators/bash_operator.py
@@ -4,9 +4,10 @@
from subprocess import Popen, STDOUT, PIPE
from tempfile import gettempdir, NamedTemporaryFile
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults, TemporaryDirectory
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.file import TemporaryDirectory
class BashOperator(BaseOperator):
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index fc441b30854b1..0624d915041d3 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -2,10 +2,10 @@
from builtins import str
import logging
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks import BaseHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class CheckOperator(BaseOperator):
diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py
index 5d397aad0688e..7f8bb53400ad7 100644
--- a/airflow/operators/dagrun_operator.py
+++ b/airflow/operators/dagrun_operator.py
@@ -2,7 +2,7 @@
import logging
from airflow.models import BaseOperator, DagRun
-from airflow.utils import apply_defaults, State
+from airflow.utils.decorators import apply_defaults
from airflow import settings
diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py
index c7b1df7f7b6a7..b01d31ac51042 100644
--- a/airflow/operators/docker_operator.py
+++ b/airflow/operators/docker_operator.py
@@ -1,7 +1,9 @@
import json
import logging
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults, AirflowException, TemporaryDirectory
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.file import TemporaryDirectory
from docker import Client, tls
import ast
diff --git a/airflow/operators/dummy_operator.py b/airflow/operators/dummy_operator.py
index 6b69115e6b2a9..1392e7d33cc98 100644
--- a/airflow/operators/dummy_operator.py
+++ b/airflow/operators/dummy_operator.py
@@ -1,5 +1,5 @@
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class DummyOperator(BaseOperator):
diff --git a/airflow/operators/email_operator.py b/airflow/operators/email_operator.py
index 0cfff08a4bfef..29b18edad0b51 100644
--- a/airflow/operators/email_operator.py
+++ b/airflow/operators/email_operator.py
@@ -1,6 +1,6 @@
from airflow.models import BaseOperator
-from airflow.utils import send_email
-from airflow.utils import apply_defaults
+from airflow.utils.email import send_email
+from airflow.utils.decorators import apply_defaults
class EmailOperator(BaseOperator):
diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py
index 7e99d3e334f65..eab9d61c0f224 100644
--- a/airflow/operators/generic_transfer.py
+++ b/airflow/operators/generic_transfer.py
@@ -1,7 +1,7 @@
import logging
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
from airflow.hooks.base_hook import BaseHook
diff --git a/airflow/operators/hive_operator.py b/airflow/operators/hive_operator.py
index 7c0d299e54d51..9a299e1e02160 100644
--- a/airflow/operators/hive_operator.py
+++ b/airflow/operators/hive_operator.py
@@ -3,7 +3,7 @@
from airflow.hooks import HiveCliHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class HiveOperator(BaseOperator):
diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py
index 09f85e17af105..aadca4de28755 100644
--- a/airflow/operators/hive_stats_operator.py
+++ b/airflow/operators/hive_stats_operator.py
@@ -4,10 +4,10 @@
import json
import logging
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks import PrestoHook, HiveMetastoreHook, MySqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class HiveStatsCollectionOperator(BaseOperator):
diff --git a/airflow/operators/hive_to_druid.py b/airflow/operators/hive_to_druid.py
index e518ea430ad2e..1346841e6f7a3 100644
--- a/airflow/operators/hive_to_druid.py
+++ b/airflow/operators/hive_to_druid.py
@@ -2,7 +2,7 @@
from airflow.hooks import HiveCliHook, DruidHook, HiveMetastoreHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class HiveToDruidTransfer(BaseOperator):
diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py
index bfbe330cfa278..9e27f38516ab5 100644
--- a/airflow/operators/hive_to_mysql.py
+++ b/airflow/operators/hive_to_mysql.py
@@ -2,7 +2,7 @@
from airflow.hooks import HiveServer2Hook, MySqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
from tempfile import NamedTemporaryFile
diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py
index cfa98142ffefd..63881ab981097 100644
--- a/airflow/operators/hive_to_samba_operator.py
+++ b/airflow/operators/hive_to_samba_operator.py
@@ -3,7 +3,7 @@
from airflow.hooks import HiveServer2Hook, SambaHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class Hive2SambaOperator(BaseOperator):
diff --git a/airflow/operators/http_operator.py b/airflow/operators/http_operator.py
index a9b2ad5e5ee67..87d1415bf625b 100644
--- a/airflow/operators/http_operator.py
+++ b/airflow/operators/http_operator.py
@@ -1,8 +1,9 @@
import logging
+from airflow.exceptions import AirflowException
from airflow.hooks import HttpHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults, AirflowException
+from airflow.utils.decorators import apply_defaults
class SimpleHttpOperator(BaseOperator):
diff --git a/airflow/operators/jdbc_operator.py b/airflow/operators/jdbc_operator.py
index 8793045fba675..5efdaf4e6ba84 100644
--- a/airflow/operators/jdbc_operator.py
+++ b/airflow/operators/jdbc_operator.py
@@ -4,7 +4,7 @@
from airflow.hooks.jdbc_hook import JdbcHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class JdbcOperator(BaseOperator):
diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py
index 3dec7cebaf619..1d5273a49105b 100644
--- a/airflow/operators/mssql_operator.py
+++ b/airflow/operators/mssql_operator.py
@@ -2,7 +2,7 @@
from airflow.hooks import MsSqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class MsSqlOperator(BaseOperator):
diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py
index 60586de7a3db8..6a981b43c8d97 100644
--- a/airflow/operators/mssql_to_hive.py
+++ b/airflow/operators/mssql_to_hive.py
@@ -8,7 +8,7 @@
from airflow.hooks import HiveCliHook, MsSqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class MsSqlToHiveTransfer(BaseOperator):
diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py
index b8d56d5097e38..ae6d36f3278af 100644
--- a/airflow/operators/mysql_operator.py
+++ b/airflow/operators/mysql_operator.py
@@ -2,7 +2,7 @@
from airflow.hooks import MySqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class MySqlOperator(BaseOperator):
diff --git a/airflow/operators/mysql_to_hive.py b/airflow/operators/mysql_to_hive.py
index 6e2a8dd58b242..09ec190f77458 100644
--- a/airflow/operators/mysql_to_hive.py
+++ b/airflow/operators/mysql_to_hive.py
@@ -7,7 +7,7 @@
from airflow.hooks import HiveCliHook, MySqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class MySqlToHiveTransfer(BaseOperator):
diff --git a/airflow/operators/pig_operator.py b/airflow/operators/pig_operator.py
index e0d91afd33067..d25795dec73d7 100644
--- a/airflow/operators/pig_operator.py
+++ b/airflow/operators/pig_operator.py
@@ -3,7 +3,7 @@
from airflow.hooks import PigCliHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class PigOperator(BaseOperator):
diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py
index a7302a050b5f0..79fa5e75330de 100644
--- a/airflow/operators/postgres_operator.py
+++ b/airflow/operators/postgres_operator.py
@@ -2,7 +2,7 @@
from airflow.hooks import PostgresHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class PostgresOperator(BaseOperator):
diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py
index 9228a93a544c3..e857036415e6e 100644
--- a/airflow/operators/presto_check_operator.py
+++ b/airflow/operators/presto_check_operator.py
@@ -1,6 +1,6 @@
from airflow.hooks import PrestoHook
from airflow.operators import CheckOperator, ValueCheckOperator, IntervalCheckOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class PrestoCheckOperator(CheckOperator):
diff --git a/airflow/operators/presto_to_mysql.py b/airflow/operators/presto_to_mysql.py
index 37c3caadcb8c5..29de0c7d86655 100644
--- a/airflow/operators/presto_to_mysql.py
+++ b/airflow/operators/presto_to_mysql.py
@@ -2,7 +2,7 @@
from airflow.hooks import PrestoHook, MySqlHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class PrestoToMySqlTransfer(BaseOperator):
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index 1cf7eed8e5a56..290cc65d139e9 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -3,7 +3,8 @@
import logging
from airflow.models import BaseOperator, TaskInstance
-from airflow.utils import apply_defaults, State
+from airflow.utils.state import State
+from airflow.utils.decorators import apply_defaults
from airflow import settings
diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py
index 837c2f902434c..ce36b00efc57d 100644
--- a/airflow/operators/s3_file_transform_operator.py
+++ b/airflow/operators/s3_file_transform_operator.py
@@ -2,10 +2,10 @@
from tempfile import NamedTemporaryFile
import subprocess
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks import S3Hook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class S3FileTransformOperator(BaseOperator):
diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py
index 20d23e781e3d8..3fc5327a40743 100644
--- a/airflow/operators/s3_to_hive_operator.py
+++ b/airflow/operators/s3_to_hive_operator.py
@@ -3,10 +3,10 @@
import logging
from tempfile import NamedTemporaryFile
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.hooks import HiveCliHook, S3Hook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class S3ToHiveTransfer(BaseOperator):
diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py
index 4c45962db0208..7d7a7c689ce6d 100644
--- a/airflow/operators/sensors.py
+++ b/airflow/operators/sensors.py
@@ -8,11 +8,11 @@
from time import sleep
from airflow import hooks, settings
+from airflow.exceptions import AirflowException, AirflowSensorTimeout
from airflow.models import BaseOperator, TaskInstance, Connection as DB
from airflow.hooks import BaseHook
-from airflow.utils import State
-from airflow.utils import (
- apply_defaults, AirflowException, AirflowSensorTimeout)
+from airflow.utils.state import State
+from airflow.utils.decorators import apply_defaults
class BaseSensorOperator(BaseOperator):
diff --git a/airflow/operators/slack_operator.py b/airflow/operators/slack_operator.py
index c8734fb954431..2f173d7edf34b 100644
--- a/airflow/operators/slack_operator.py
+++ b/airflow/operators/slack_operator.py
@@ -1,6 +1,7 @@
from slackclient import SlackClient
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults, AirflowException
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
import json
import logging
diff --git a/airflow/operators/sqlite_operator.py b/airflow/operators/sqlite_operator.py
index ebdba2f5ce725..700019d9ead8b 100644
--- a/airflow/operators/sqlite_operator.py
+++ b/airflow/operators/sqlite_operator.py
@@ -2,7 +2,7 @@
from airflow.hooks import SqliteHook
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
class SqliteOperator(BaseOperator):
diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py
index 54c2409d7e73b..c56e7afc54066 100644
--- a/airflow/operators/subdag_operator.py
+++ b/airflow/operators/subdag_operator.py
@@ -1,6 +1,6 @@
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
-from airflow.utils import apply_defaults
+from airflow.utils.decorators import apply_defaults
from airflow.executors import DEFAULT_EXECUTOR
diff --git a/airflow/settings.py b/airflow/settings.py
index 51dfe4d153717..ae56455649fe5 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -109,6 +109,7 @@ def policy(task_instance):
"""
pass
+
def configure_logging():
logging.root.handlers = []
logging.basicConfig(
diff --git a/airflow/utils.py b/airflow/utils.py
deleted file mode 100644
index 228602e0d8888..0000000000000
--- a/airflow/utils.py
+++ /dev/null
@@ -1,978 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-import sys
-from builtins import str, input, object
-from past.builtins import basestring
-from copy import copy
-from datetime import datetime, date, timedelta
-from dateutil.relativedelta import relativedelta # for doctest
-from email.mime.text import MIMEText
-from email.mime.multipart import MIMEMultipart
-from email.mime.application import MIMEApplication
-from email.utils import formatdate
-import errno
-from functools import wraps
-import imp
-import importlib
-import inspect
-import json
-import logging
-import os
-import re
-import shutil
-import signal
-import six
-import smtplib
-from tempfile import mkdtemp
-
-from alembic.config import Config
-from alembic import command
-from alembic.migration import MigrationContext
-
-from contextlib import contextmanager
-
-from sqlalchemy import event, exc
-from sqlalchemy.pool import Pool
-
-import numpy as np
-from croniter import croniter
-
-from airflow import settings
-from airflow import configuration
-
-
-class AirflowException(Exception):
- pass
-
-
-class AirflowSensorTimeout(Exception):
- pass
-
-
-class TriggerRule(object):
- ALL_SUCCESS = 'all_success'
- ALL_FAILED = 'all_failed'
- ALL_DONE = 'all_done'
- ONE_SUCCESS = 'one_success'
- ONE_FAILED = 'one_failed'
- DUMMY = 'dummy'
-
- @classmethod
- def is_valid(cls, trigger_rule):
- return trigger_rule in cls.all_triggers()
-
- @classmethod
- def all_triggers(cls):
- return [getattr(cls, attr)
- for attr in dir(cls)
- if not attr.startswith("__") and not callable(getattr(cls, attr))]
-
-
-class State(object):
- """
- Static class with task instance states constants and color method to
- avoid hardcoding.
- """
- QUEUED = "queued"
- RUNNING = "running"
- SUCCESS = "success"
- SHUTDOWN = "shutdown" # External request to shut down
- FAILED = "failed"
- UP_FOR_RETRY = "up_for_retry"
- UPSTREAM_FAILED = "upstream_failed"
- SKIPPED = "skipped"
-
- state_color = {
- QUEUED: 'gray',
- RUNNING: 'lime',
- SUCCESS: 'green',
- SHUTDOWN: 'blue',
- FAILED: 'red',
- UP_FOR_RETRY: 'gold',
- UPSTREAM_FAILED: 'orange',
- SKIPPED: 'pink',
- }
-
- @classmethod
- def color(cls, state):
- if state in cls.state_color:
- return cls.state_color[state]
- else:
- return 'white'
-
- @classmethod
- def color_fg(cls, state):
- color = cls.color(state)
- if color in ['green', 'red']:
- return 'white'
- else:
- return 'black'
-
- @classmethod
- def runnable(cls):
- return [
- None, cls.FAILED, cls.UP_FOR_RETRY, cls.UPSTREAM_FAILED,
- cls.SKIPPED, cls.QUEUED]
-
-
-cron_presets = {
- '@hourly': '0 * * * *',
- '@daily': '0 0 * * *',
- '@weekly': '0 0 * * 0',
- '@monthly': '0 0 1 * *',
- '@yearly': '0 0 1 1 *',
-}
-
-def provide_session(func):
- """
- Function decorator that provides a session if it isn't provided.
- If you want to reuse a session or run the function as part of a
- database transaction, you pass it to the function, if not this wrapper
- will create one and close it for you.
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
- needs_session = False
- if 'session' not in kwargs:
- needs_session = True
- session = settings.Session()
- kwargs['session'] = session
- result = func(*args, **kwargs)
- if needs_session:
- session.expunge_all()
- session.commit()
- session.close()
- return result
- return wrapper
-
-
-def pessimistic_connection_handling():
- @event.listens_for(Pool, "checkout")
- def ping_connection(dbapi_connection, connection_record, connection_proxy):
- '''
- Disconnect Handling - Pessimistic, taken from:
- http://docs.sqlalchemy.org/en/rel_0_9/core/pooling.html
- '''
- cursor = dbapi_connection.cursor()
- try:
- cursor.execute("SELECT 1")
- except:
- raise exc.DisconnectionError()
- cursor.close()
-
-@provide_session
-def merge_conn(conn, session=None):
- from airflow import models
- C = models.Connection
- if not session.query(C).filter(C.conn_id == conn.conn_id).first():
- session.add(conn)
- session.commit()
-
-
-def initdb():
- session = settings.Session()
-
- from airflow import models
- upgradedb()
-
- merge_conn(
- models.Connection(
- conn_id='airflow_db', conn_type='mysql',
- host='localhost', login='root',
- schema='airflow'))
- merge_conn(
- models.Connection(
- conn_id='airflow_ci', conn_type='mysql',
- host='localhost', login='root',
- schema='airflow_ci'))
- merge_conn(
- models.Connection(
- conn_id='beeline_default', conn_type='beeline', port="10000",
- host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}",
- schema='default'))
- merge_conn(
- models.Connection(
- conn_id='bigquery_default', conn_type='bigquery'))
- merge_conn(
- models.Connection(
- conn_id='local_mysql', conn_type='mysql',
- host='localhost', login='airflow', password='airflow',
- schema='airflow'))
- merge_conn(
- models.Connection(
- conn_id='presto_default', conn_type='presto',
- host='localhost',
- schema='hive', port=3400))
- merge_conn(
- models.Connection(
- conn_id='hive_cli_default', conn_type='hive_cli',
- schema='default',))
- merge_conn(
- models.Connection(
- conn_id='hiveserver2_default', conn_type='hiveserver2',
- host='localhost',
- schema='default', port=10000))
- merge_conn(
- models.Connection(
- conn_id='metastore_default', conn_type='hive_metastore',
- host='localhost', extra="{\"authMechanism\": \"PLAIN\"}",
- port=9083))
- merge_conn(
- models.Connection(
- conn_id='mysql_default', conn_type='mysql',
- login='root',
- host='localhost'))
- merge_conn(
- models.Connection(
- conn_id='postgres_default', conn_type='postgres',
- login='postgres',
- schema='airflow',
- host='localhost'))
- merge_conn(
- models.Connection(
- conn_id='sqlite_default', conn_type='sqlite',
- host='/tmp/sqlite_default.db'))
- merge_conn(
- models.Connection(
- conn_id='http_default', conn_type='http',
- host='https://www.google.com/'))
- merge_conn(
- models.Connection(
- conn_id='mssql_default', conn_type='mssql',
- host='localhost', port=1433))
- merge_conn(
- models.Connection(
- conn_id='vertica_default', conn_type='vertica',
- host='localhost', port=5433))
- merge_conn(
- models.Connection(
- conn_id='webhdfs_default', conn_type='hdfs',
- host='localhost', port=50070))
- merge_conn(
- models.Connection(
- conn_id='ssh_default', conn_type='ssh',
- host='localhost'))
-
- # Known event types
- KET = models.KnownEventType
- if not session.query(KET).filter(KET.know_event_type == 'Holiday').first():
- session.add(KET(know_event_type='Holiday'))
- if not session.query(KET).filter(KET.know_event_type == 'Outage').first():
- session.add(KET(know_event_type='Outage'))
- if not session.query(KET).filter(
- KET.know_event_type == 'Natural Disaster').first():
- session.add(KET(know_event_type='Natural Disaster'))
- if not session.query(KET).filter(
- KET.know_event_type == 'Marketing Campaign').first():
- session.add(KET(know_event_type='Marketing Campaign'))
- session.commit()
-
- models.DagBag(sync_to_db=True)
-
- Chart = models.Chart
- chart_label = "Airflow task instance by type"
- chart = session.query(Chart).filter(Chart.label == chart_label).first()
- if not chart:
- chart = Chart(
- label=chart_label,
- conn_id='airflow_db',
- chart_type='bar',
- x_is_date=False,
- sql=(
- "SELECT state, COUNT(1) as number "
- "FROM task_instance "
- "WHERE dag_id LIKE 'example%' "
- "GROUP BY state"),
- )
- session.add(chart)
-
-
-def upgradedb():
- logging.info("Creating tables")
- package_dir = os.path.abspath(os.path.dirname(__file__))
- directory = os.path.join(package_dir, 'migrations')
- config = Config(os.path.join(package_dir, 'alembic.ini'))
- config.set_main_option('script_location', directory)
- config.set_main_option('sqlalchemy.url',
- configuration.get('core', 'SQL_ALCHEMY_CONN'))
- command.upgrade(config, 'heads')
-
-
-def resetdb():
- '''
- Clear out the database
- '''
- from airflow import models
-
- logging.info("Dropping tables that exist")
- models.Base.metadata.drop_all(settings.engine)
- mc = MigrationContext.configure(settings.engine)
- if mc._version.exists(settings.engine):
- mc._version.drop(settings.engine)
- initdb()
-
-
-def validate_key(k, max_length=250):
- if not isinstance(k, basestring):
- raise TypeError("The key has to be a string")
- elif len(k) > max_length:
- raise AirflowException(
- "The key has to be less than {0} characters".format(max_length))
- elif not re.match(r'^[A-Za-z0-9_\-\.]+$', k):
- raise AirflowException(
- "The key ({k}) has to be made of alphanumeric characters, dashes, "
- "dots and underscores exclusively".format(**locals()))
- else:
- return True
-
-
-def date_range(
- start_date,
- end_date=None,
- num=None,
- delta=None):
- """
- Get a set of dates as a list based on a start, end and delta, delta
- can be something that can be added to ``datetime.datetime``
- or a cron expression as a ``str``
-
- :param start_date: anchor date to start the series from
- :type start_date: datetime.datetime
- :param end_date: right boundary for the date range
- :type end_date: datetime.datetime
- :param num: alternatively to end_date, you can specify the number of
- number of entries you want in the range. This number can be negative,
- output will always be sorted regardless
- :type num: int
-
- >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=timedelta(1))
- [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)]
- >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta='0 0 * * *')
- [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)]
- >>> date_range(datetime(2016, 1, 1), datetime(2016, 3, 3), delta="0 0 0 * *")
- [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 2, 1, 0, 0), datetime.datetime(2016, 3, 1, 0, 0)]
- """
- if not delta:
- return []
- if end_date and start_date > end_date:
- raise Exception("Wait. start_date needs to be before end_date")
- if end_date and num:
- raise Exception("Wait. Either specify end_date OR num")
- if not end_date and not num:
- end_date = datetime.now()
-
- delta_iscron = False
- if isinstance(delta, six.string_types):
- delta_iscron = True
- cron = croniter(delta, start_date)
- elif isinstance(delta, timedelta):
- delta = abs(delta)
- l = []
- if end_date:
- while start_date <= end_date:
- l.append(start_date)
- if delta_iscron:
- start_date = cron.get_next(datetime)
- else:
- start_date += delta
- else:
- for i in range(abs(num)):
- l.append(start_date)
- if delta_iscron:
- if num > 0:
- start_date = cron.get_next(datetime)
- else:
- start_date = cron.get_prev(datetime)
- else:
- if num > 0:
- start_date += delta
- else:
- start_date -= delta
- return sorted(l)
-
-
-def json_ser(obj):
- """
- json serializer that deals with dates
- usage: json.dumps(object, default=utils.json_ser)
- """
- if isinstance(obj, (datetime, date)):
- return obj.isoformat()
-
-
-def alchemy_to_dict(obj):
- """
- Transforms a SQLAlchemy model instance into a dictionary
- """
- if not obj:
- return None
- d = {}
- for c in obj.__table__.columns:
- value = getattr(obj, c.name)
- if type(value) == datetime:
- value = value.isoformat()
- d[c.name] = value
- return d
-
-
-def readfile(filepath):
- f = open(filepath)
- content = f.read()
- f.close()
- return content
-
-
-def apply_defaults(func):
- """
- Function decorator that Looks for an argument named "default_args", and
- fills the unspecified arguments from it.
-
- Since python2.* isn't clear about which arguments are missing when
- calling a function, and that this can be quite confusing with multi-level
- inheritance and argument defaults, this decorator also alerts with
- specific information about the missing arguments.
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
- if len(args) > 1:
- raise AirflowException(
- "Use keyword arguments when initializing operators")
- dag_args = {}
- dag_params = {}
- if 'dag' in kwargs and kwargs['dag']:
- dag = kwargs['dag']
- dag_args = copy(dag.default_args) or {}
- dag_params = copy(dag.params) or {}
-
- params = {}
- if 'params' in kwargs:
- params = kwargs['params']
- dag_params.update(params)
-
- default_args = {}
- if 'default_args' in kwargs:
- default_args = kwargs['default_args']
- if 'params' in default_args:
- dag_params.update(default_args['params'])
- del default_args['params']
-
- dag_args.update(default_args)
- default_args = dag_args
- arg_spec = inspect.getargspec(func)
- num_defaults = len(arg_spec.defaults) if arg_spec.defaults else 0
- non_optional_args = arg_spec.args[:-num_defaults]
- if 'self' in non_optional_args:
- non_optional_args.remove('self')
- for arg in func.__code__.co_varnames:
- if arg in default_args and arg not in kwargs:
- kwargs[arg] = default_args[arg]
- missing_args = list(set(non_optional_args) - set(kwargs))
- if missing_args:
- msg = "Argument {0} is required".format(missing_args)
- raise AirflowException(msg)
-
- kwargs['params'] = dag_params
-
- result = func(*args, **kwargs)
- return result
- return wrapper
-
-if 'BUILDING_AIRFLOW_DOCS' in os.environ:
- # Monkey patch hook to get good function headers while building docs
- apply_defaults = lambda x: x
-
-def ask_yesno(question):
- yes = set(['yes', 'y'])
- no = set(['no', 'n'])
-
- done = False
- print(question)
- while not done:
- choice = input().lower()
- if choice in yes:
- return True
- elif choice in no:
- return False
- else:
- print("Please respond by yes or no.")
-
-
-def send_email(to, subject, html_content, files=None, dryrun=False):
- """
- Send email using backend specified in EMAIL_BACKEND.
- """
- path, attr = configuration.get('email', 'EMAIL_BACKEND').rsplit('.', 1)
- module = importlib.import_module(path)
- backend = getattr(module, attr)
- return backend(to, subject, html_content, files=files, dryrun=dryrun)
-
-
-def send_email_smtp(to, subject, html_content, files=None, dryrun=False):
- """
- Send an email with html content
-
- >>> send_email('test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True)
- """
- SMTP_MAIL_FROM = configuration.get('smtp', 'SMTP_MAIL_FROM')
-
- if isinstance(to, basestring):
- if ',' in to:
- to = to.split(',')
- elif ';' in to:
- to = to.split(';')
- else:
- to = [to]
-
- msg = MIMEMultipart('alternative')
- msg['Subject'] = subject
- msg['From'] = SMTP_MAIL_FROM
- msg['To'] = ", ".join(to)
- msg["Date"] = formatdate(localtime=True)
- mime_text = MIMEText(html_content, 'html')
- msg.attach(mime_text)
-
- for fname in files or []:
- basename = os.path.basename(fname)
- with open(fname, "rb") as f:
- msg.attach(MIMEApplication(
- f.read(),
- Content_Disposition='attachment; filename="%s"' % basename,
- Name=basename
- ))
-
- send_MIME_email(SMTP_MAIL_FROM, to, msg, dryrun)
-
-
-def send_MIME_email(e_from, e_to, mime_msg, dryrun=False):
- SMTP_HOST = configuration.get('smtp', 'SMTP_HOST')
- SMTP_PORT = configuration.getint('smtp', 'SMTP_PORT')
- SMTP_USER = configuration.get('smtp', 'SMTP_USER')
- SMTP_PASSWORD = configuration.get('smtp', 'SMTP_PASSWORD')
- SMTP_STARTTLS = configuration.getboolean('smtp', 'SMTP_STARTTLS')
- SMTP_SSL = configuration.getboolean('smtp', 'SMTP_SSL')
-
- if not dryrun:
- s = smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT) if SMTP_SSL else smtplib.SMTP(SMTP_HOST, SMTP_PORT)
- if SMTP_STARTTLS:
- s.starttls()
- if SMTP_USER and SMTP_PASSWORD:
- s.login(SMTP_USER, SMTP_PASSWORD)
- logging.info("Sent an alert email to " + str(e_to))
- s.sendmail(e_from, e_to, mime_msg.as_string())
- s.quit()
-
-
-def import_module_attrs(parent_module_globals, module_attrs_dict):
- '''
- Attempts to import a set of modules and specified attributes in the
- form of a dictionary. The attributes are copied in the parent module's
- namespace. The function returns a list of attributes names that can be
- affected to __all__.
-
- This is used in the context of ``operators`` and ``hooks`` and
- silence the import errors for when libraries are missing. It makes
- for a clean package abstracting the underlying modules and only
- brings functional operators to those namespaces.
- '''
- imported_attrs = []
- for mod, attrs in list(module_attrs_dict.items()):
- try:
- path = os.path.realpath(parent_module_globals['__file__'])
- folder = os.path.dirname(path)
- f, filename, description = imp.find_module(mod, [folder])
- module = imp.load_module(mod, f, filename, description)
- for attr in attrs:
- parent_module_globals[attr] = getattr(module, attr)
- imported_attrs += [attr]
- except Exception as err:
- logging.debug("Error importing module {mod}: {err}".format(
- mod=mod, err=err))
- return imported_attrs
-
-
-def is_in(obj, l):
- """
- Checks whether an object is one of the item in the list.
- This is different from ``in`` because ``in`` uses __cmp__ when
- present. Here we change based on the object itself
- """
- for item in l:
- if item is obj:
- return True
- return False
-
-
-@contextmanager
-def TemporaryDirectory(suffix='', prefix=None, dir=None):
- name = mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
- try:
- yield name
- finally:
- try:
- shutil.rmtree(name)
- except OSError as e:
- # ENOENT - no such file or directory
- if e.errno != errno.ENOENT:
- raise e
-
-
-class AirflowTaskTimeout(Exception):
- pass
-
-
-class timeout(object):
- """
- To be used in a ``with`` block and timeout its content.
- """
- def __init__(self, seconds=1, error_message='Timeout'):
- self.seconds = seconds
- self.error_message = error_message
-
- def handle_timeout(self, signum, frame):
- logging.error("Process timed out")
- raise AirflowTaskTimeout(self.error_message)
-
- def __enter__(self):
- try:
- signal.signal(signal.SIGALRM, self.handle_timeout)
- signal.alarm(self.seconds)
- except ValueError as e:
- logging.warning("timeout can't be used in the current context")
- logging.exception(e)
-
- def __exit__(self, type, value, traceback):
- try:
- signal.alarm(0)
- except ValueError as e:
- logging.warning("timeout can't be used in the current context")
- logging.exception(e)
-
-
-def is_container(obj):
- """
- Test if an object is a container (iterable) but not a string
- """
- return hasattr(obj, '__iter__') and not isinstance(obj, basestring)
-
-
-def as_tuple(obj):
- """
- If obj is a container, returns obj as a tuple.
- Otherwise, returns a tuple containing obj.
- """
- if is_container(obj):
- return tuple(obj)
- else:
- return tuple([obj])
-
-
-def round_time(dt, delta, start_date=datetime.min):
- """
- Returns the datetime of the form start_date + i * delta
- which is closest to dt for any non-negative integer i.
-
- Note that delta may be a datetime.timedelta or a dateutil.relativedelta
-
- >>> round_time(datetime(2015, 1, 1, 6), timedelta(days=1))
- datetime.datetime(2015, 1, 1, 0, 0)
- >>> round_time(datetime(2015, 1, 2), relativedelta(months=1))
- datetime.datetime(2015, 1, 1, 0, 0)
- >>> round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
- datetime.datetime(2015, 9, 16, 0, 0)
- >>> round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
- datetime.datetime(2015, 9, 15, 0, 0)
- >>> round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
- datetime.datetime(2015, 9, 14, 0, 0)
- >>> round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
- datetime.datetime(2015, 9, 14, 0, 0)
- """
-
- if isinstance(delta, six.string_types):
- # It's cron based, so it's easy
- cron = croniter(delta, start_date)
- prev = cron.get_prev(datetime)
- if prev == start_date:
- return start_date
- else:
- return prev
-
- # Ignore the microseconds of dt
- dt -= timedelta(microseconds = dt.microsecond)
-
- # We are looking for a datetime in the form start_date + i * delta
- # which is as close as possible to dt. Since delta could be a relative
- # delta we don't know it's exact length in seconds so we cannot rely on
- # division to find i. Instead we employ a binary search algorithm, first
- # finding an upper and lower limit and then disecting the interval until
- # we have found the closest match.
-
- # We first search an upper limit for i for which start_date + upper * delta
- # exceeds dt.
- upper = 1
- while start_date + upper*delta < dt:
- # To speed up finding an upper limit we grow this exponentially by a
- # factor of 2
- upper *= 2
-
- # Since upper is the first value for which start_date + upper * delta
- # exceeds dt, upper // 2 is below dt and therefore forms a lower limited
- # for the i we are looking for
- lower = upper // 2
-
- # We now continue to intersect the interval between
- # start_date + lower * delta and start_date + upper * delta
- # until we find the closest value
- while True:
- # Invariant: start + lower * delta < dt <= start + upper * delta
- # If start_date + (lower + 1)*delta exceeds dt, then either lower or
- # lower+1 has to be the solution we are searching for
- if start_date + (lower + 1)*delta >= dt:
- # Check if start_date + (lower + 1)*delta or
- # start_date + lower*delta is closer to dt and return the solution
- if (
- (start_date + (lower + 1) * delta) - dt <=
- dt - (start_date + lower * delta)):
- return start_date + (lower + 1)*delta
- else:
- return start_date + lower * delta
-
- # We intersect the interval and either replace the lower or upper
- # limit with the candidate
- candidate = lower + (upper - lower) // 2
- if start_date + candidate*delta >= dt:
- upper = candidate
- else:
- lower = candidate
-
- # in the special case when start_date > dt the search for upper will
- # immediately stop for upper == 1 which results in lower = upper // 2 = 0
- # and this function returns start_date.
-
-
-def chain(*tasks):
- """
- Given a number of tasks, builds a dependency chain.
-
- chain(task_1, task_2, task_3, task_4)
-
- is equivalent to
-
- task_1.set_downstream(task_2)
- task_2.set_downstream(task_3)
- task_3.set_downstream(task_4)
- """
- for up_task, down_task in zip(tasks[:-1], tasks[1:]):
- up_task.set_downstream(down_task)
-
-
-class AirflowJsonEncoder(json.JSONEncoder):
- def default(self, obj):
- # convert dates and numpy objects in a json serializable format
- if isinstance(obj, datetime):
- return obj.strftime('%Y-%m-%dT%H:%M:%SZ')
- elif isinstance(obj, date):
- return obj.strftime('%Y-%m-%d')
- elif type(obj) in [np.int_, np.intc, np.intp, np.int8, np.int16,
- np.int32, np.int64, np.uint8, np.uint16,
- np.uint32, np.uint64]:
- return int(obj)
- elif type(obj) in [np.bool_]:
- return bool(obj)
- elif type(obj) in [np.float_, np.float16, np.float32, np.float64,
- np.complex_, np.complex64, np.complex128]:
- return float(obj)
-
- # Let the base class default method raise the TypeError
- return json.JSONEncoder.default(self, obj)
-
-
-class LoggingMixin(object):
- """
- Convenience super-class to have a logger configured with the class name
- """
-
- @property
- def logger(self):
- try:
- return self._logger
- except AttributeError:
- self._logger = logging.root.getChild(self.__class__.__module__ + '.' +self.__class__.__name__)
- return self._logger
-
-
-class S3Log(object):
- """
- Utility class for reading and writing logs in S3.
- Requires airflow[s3] and setting the REMOTE_BASE_LOG_FOLDER and
- REMOTE_LOG_CONN_ID configuration options in airflow.cfg.
- """
- def __init__(self):
- remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID')
- try:
- from airflow.hooks import S3Hook
- self.hook = S3Hook(remote_conn_id)
- except:
- self.hook = None
- logging.error(
- 'Could not create an S3Hook with connection id "{}". '
- 'Please make sure that airflow[s3] is installed and '
- 'the S3 connection exists.'.format(remote_conn_id))
-
- def read(self, remote_log_location, return_error=False):
- """
- Returns the log found at the remote_log_location. Returns '' if no
- logs are found or there is an error.
-
- :param remote_log_location: the log's location in remote storage
- :type remote_log_location: string (path)
- :param return_error: if True, returns a string error message if an
- error occurs. Otherwise returns '' when an error occurs.
- :type return_error: bool
- """
- if self.hook:
- try:
- s3_key = self.hook.get_key(remote_log_location)
- if s3_key:
- return s3_key.get_contents_as_string().decode()
- except:
- pass
-
- # raise/return error if we get here
- err = 'Could not read logs from {}'.format(remote_log_location)
- logging.error(err)
- return err if return_error else ''
-
-
- def write(self, log, remote_log_location, append=False):
- """
- Writes the log to the remote_log_location. Fails silently if no hook
- was created.
-
- :param log: the log to write to the remote_log_location
- :type log: string
- :param remote_log_location: the log's location in remote storage
- :type remote_log_location: string (path)
- :param append: if False, any existing log file is overwritten. If True,
- the new log is appended to any existing logs.
- :type append: bool
-
- """
- if self.hook:
-
- if append:
- old_log = self.read(remote_log_location)
- log = old_log + '\n' + log
- try:
- self.hook.load_string(
- log,
- key=remote_log_location,
- replace=True,
- encrypt=configuration.get('core', 'ENCRYPT_S3_LOGS'))
- return
- except:
- pass
-
- # raise/return error if we get here
- logging.error('Could not write logs to {}'.format(remote_log_location))
-
-
-class GCSLog(object):
- """
- Utility class for reading and writing logs in GCS.
- Requires either airflow[gcloud] or airflow[gcp_api] and
- setting the REMOTE_BASE_LOG_FOLDER and REMOTE_LOG_CONN_ID configuration
- options in airflow.cfg.
- """
- def __init__(self):
- """
- Attempt to create hook with airflow[gcloud] (and set
- use_gcloud = True), otherwise uses airflow[gcp_api]
- """
- remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID')
- self.use_gcloud = False
-
- try:
- from airflow.contrib.hooks import GCSHook
- self.hook = GCSHook(remote_conn_id)
- self.use_gcloud = True
- except:
- try:
- from airflow.contrib.hooks import GoogleCloudStorageHook
- self.hook = GoogleCloudStorageHook(remote_conn_id)
- except:
- self.hook = None
- logging.error(
- 'Could not create a GCSHook with connection id "{}". '
- 'Please make sure that either airflow[gcloud] or '
- 'airflow[gcp_api] is installed and the GCS connection '
- 'exists.'.format(remote_conn_id))
-
- def read(self, remote_log_location, return_error=True):
- """
- Returns the log found at the remote_log_location.
-
- :param remote_log_location: the log's location in remote storage
- :type remote_log_location: string (path)
- :param return_error: if True, returns a string error message if an
- error occurs. Otherwise returns '' when an error occurs.
- :type return_error: bool
- """
- if self.hook:
- try:
- if self.use_gcloud:
- gcs_blob = self.hook.get_blob(remote_log_location)
- if gcs_blob:
- return gcs_blob.download_as_string().decode()
- else:
- bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1)
- return self.hook.download(bkt, blob).decode()
- except:
- pass
-
- # raise/return error if we get here
- err = 'Could not read logs from {}'.format(remote_log_location)
- logging.error(err)
- return err if return_error else ''
-
- def write(self, log, remote_log_location, append=False):
- """
- Writes the log to the remote_log_location. Fails silently if no hook
- was created.
-
- :param log: the log to write to the remote_log_location
- :type log: string
- :param remote_log_location: the log's location in remote storage
- :type remote_log_location: string (path)
- :param append: if False, any existing log file is overwritten. If True,
- the new log is appended to any existing logs.
- :type append: bool
-
- """
- if self.hook:
-
- if append:
- old_log = self.read(remote_log_location)
- log = old_log + '\n' + log
-
- try:
- if self.use_gcloud:
- self.hook.upload_from_string(
- log,
- blob=remote_log_location,
- replace=True)
- return
- else:
- bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1)
- from tempfile import NamedTemporaryFile
- with NamedTemporaryFile(mode='w+') as tmpfile:
- tmpfile.write(log)
- self.hook.upload(bkt, blob, tmpfile.name)
- return
- except:
- pass
-
- # raise/return error if we get here
- logging.error('Could not write logs to {}'.format(remote_log_location))
diff --git a/airflow/utils/__init__.py b/airflow/utils/__init__.py
new file mode 100644
index 0000000000000..58f1db3855dde
--- /dev/null
+++ b/airflow/utils/__init__.py
@@ -0,0 +1,21 @@
+from __future__ import absolute_import
+
+import warnings
+
+from .decorators import apply_defaults as _apply_defaults
+
+
+def apply_defaults(func):
+ warnings.warn_explicit(
+ """
+ You are importing apply_defaults from airflow.utils which
+ will be deprecated in a future version.
+        Please use:
+
+ from airflow.utils.decorators import apply_defaults
+ """,
+ category=PendingDeprecationWarning,
+        filename=func.__code__.co_filename,
+        lineno=func.__code__.co_firstlineno + 1
+ )
+ return _apply_defaults(func)
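
For callers, the shim above keeps the historical import path alive while pointing at the
new location. A minimal sketch of the two spellings (illustrative only, not part of the patch):

    # old spelling resolves to the shim above; it emits a
    # PendingDeprecationWarning as soon as it is applied to a function
    from airflow.utils import apply_defaults as legacy_apply_defaults

    # preferred spelling going forward
    from airflow.utils.decorators import apply_defaults
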
diff --git a/airflow/ascii.py b/airflow/utils/asciiart.py
similarity index 99%
rename from airflow/ascii.py
rename to airflow/utils/asciiart.py
index 60e02482bae96..85036ca4388a9 100644
--- a/airflow/ascii.py
+++ b/airflow/utils/asciiart.py
@@ -39,4 +39,3 @@
(/ / // /|//||||\\ \ \ \ _)
-------------------------------------------------------------------------------
"""
-
diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py
new file mode 100644
index 0000000000000..121b775a35def
--- /dev/null
+++ b/airflow/utils/dates.py
@@ -0,0 +1,167 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from datetime import datetime, date, timedelta
+from dateutil.relativedelta import relativedelta # for doctest
+import six
+
+from croniter import croniter
+
+
+cron_presets = {
+ '@hourly': '0 * * * *',
+ '@daily': '0 0 * * *',
+ '@weekly': '0 0 * * 0',
+ '@monthly': '0 0 1 * *',
+ '@yearly': '0 0 1 1 *',
+}
+
+
+def date_range(
+ start_date,
+ end_date=None,
+ num=None,
+ delta=None):
+ """
+    Get a set of dates as a list, based on a start date, an end date and a
+    delta. The delta can be anything that can be added to
+    ``datetime.datetime``, or a cron expression as a ``str``.
+
+ :param start_date: anchor date to start the series from
+ :type start_date: datetime.datetime
+ :param end_date: right boundary for the date range
+ :type end_date: datetime.datetime
+    :param num: alternatively to end_date, you can specify the number of
+        entries you want in the range. This number can be negative; the
+        output will always be sorted regardless
+ :type num: int
+
+ >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=timedelta(1))
+ [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)]
+ >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta='0 0 * * *')
+ [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)]
+ >>> date_range(datetime(2016, 1, 1), datetime(2016, 3, 3), delta="0 0 0 * *")
+ [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 2, 1, 0, 0), datetime.datetime(2016, 3, 1, 0, 0)]
+ """
+ if not delta:
+ return []
+ if end_date and start_date > end_date:
+ raise Exception("Wait. start_date needs to be before end_date")
+ if end_date and num:
+ raise Exception("Wait. Either specify end_date OR num")
+ if not end_date and not num:
+ end_date = datetime.now()
+
+ delta_iscron = False
+ if isinstance(delta, six.string_types):
+ delta_iscron = True
+ cron = croniter(delta, start_date)
+ elif isinstance(delta, timedelta):
+ delta = abs(delta)
+ l = []
+ if end_date:
+ while start_date <= end_date:
+ l.append(start_date)
+ if delta_iscron:
+ start_date = cron.get_next(datetime)
+ else:
+ start_date += delta
+ else:
+ for i in range(abs(num)):
+ l.append(start_date)
+ if delta_iscron:
+ if num > 0:
+ start_date = cron.get_next(datetime)
+ else:
+ start_date = cron.get_prev(datetime)
+ else:
+ if num > 0:
+ start_date += delta
+ else:
+ start_date -= delta
+ return sorted(l)
+
+
+def round_time(dt, delta, start_date=datetime.min):
+ """
+ Returns the datetime of the form start_date + i * delta
+ which is closest to dt for any non-negative integer i.
+
+ Note that delta may be a datetime.timedelta or a dateutil.relativedelta
+
+ >>> round_time(datetime(2015, 1, 1, 6), timedelta(days=1))
+ datetime.datetime(2015, 1, 1, 0, 0)
+ >>> round_time(datetime(2015, 1, 2), relativedelta(months=1))
+ datetime.datetime(2015, 1, 1, 0, 0)
+ >>> round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
+ datetime.datetime(2015, 9, 16, 0, 0)
+ >>> round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
+ datetime.datetime(2015, 9, 15, 0, 0)
+ >>> round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
+ datetime.datetime(2015, 9, 14, 0, 0)
+ >>> round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
+ datetime.datetime(2015, 9, 14, 0, 0)
+ """
+
+ if isinstance(delta, six.string_types):
+ # It's cron based, so it's easy
+ cron = croniter(delta, start_date)
+ prev = cron.get_prev(datetime)
+ if prev == start_date:
+ return start_date
+ else:
+ return prev
+
+ # Ignore the microseconds of dt
+ dt -= timedelta(microseconds=dt.microsecond)
+
+ # We are looking for a datetime in the form start_date + i * delta
+ # which is as close as possible to dt. Since delta could be a relative
+    # delta we don't know its exact length in seconds so we cannot rely on
+    # division to find i. Instead we employ a binary search algorithm, first
+    # finding an upper and lower limit and then dissecting the interval until
+ # we have found the closest match.
+
+ # We first search an upper limit for i for which start_date + upper * delta
+ # exceeds dt.
+ upper = 1
+ while start_date + upper*delta < dt:
+ # To speed up finding an upper limit we grow this exponentially by a
+ # factor of 2
+ upper *= 2
+
+ # Since upper is the first value for which start_date + upper * delta
+    # exceeds dt, upper // 2 is below dt and therefore forms a lower limit
+ # for the i we are looking for
+ lower = upper // 2
+
+ # We now continue to intersect the interval between
+ # start_date + lower * delta and start_date + upper * delta
+ # until we find the closest value
+ while True:
+ # Invariant: start + lower * delta < dt <= start + upper * delta
+ # If start_date + (lower + 1)*delta exceeds dt, then either lower or
+ # lower+1 has to be the solution we are searching for
+ if start_date + (lower + 1)*delta >= dt:
+ # Check if start_date + (lower + 1)*delta or
+ # start_date + lower*delta is closer to dt and return the solution
+ if (
+ (start_date + (lower + 1) * delta) - dt <=
+ dt - (start_date + lower * delta)):
+ return start_date + (lower + 1)*delta
+ else:
+ return start_date + lower * delta
+
+ # We intersect the interval and either replace the lower or upper
+ # limit with the candidate
+ candidate = lower + (upper - lower) // 2
+ if start_date + candidate*delta >= dt:
+ upper = candidate
+ else:
+ lower = candidate
+
+ # in the special case when start_date > dt the search for upper will
+ # immediately stop for upper == 1 which results in lower = upper // 2 = 0
+ # and this function returns start_date.
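
As a rough usage sketch of the new airflow.utils.dates helpers, mirroring the doctests
above (illustrative only):

    from datetime import datetime, timedelta

    from airflow.utils.dates import cron_presets, date_range, round_time

    # a negative num walks backwards from the anchor date; the result is
    # still returned in ascending order
    date_range(datetime(2016, 1, 10), num=-3, delta=timedelta(days=1))
    # -> [Jan 8, Jan 9, Jan 10]

    # cron_presets expands the '@daily' style shorthands to five-field strings
    date_range(datetime(2016, 1, 1), datetime(2016, 1, 3),
               delta=cron_presets['@daily'])

    # snap a timestamp onto a daily grid anchored at datetime.min
    round_time(datetime(2016, 1, 2, 13, 37), timedelta(days=1))
    # -> datetime(2016, 1, 3, 0, 0)
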
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
new file mode 100644
index 0000000000000..b01a946fc1bb5
--- /dev/null
+++ b/airflow/utils/db.py
@@ -0,0 +1,209 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from functools import wraps
+import logging
+import os
+
+from alembic.config import Config
+from alembic import command
+from alembic.migration import MigrationContext
+
+from sqlalchemy import event, exc
+from sqlalchemy.pool import Pool
+
+from airflow import settings
+from airflow import configuration
+
+
+def provide_session(func):
+ """
+ Function decorator that provides a session if it isn't provided.
+ If you want to reuse a session or run the function as part of a
+    database transaction, pass the session to the function. If not, this
+    wrapper will create one and close it for you.
+ """
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ needs_session = False
+ if 'session' not in kwargs:
+ needs_session = True
+ session = settings.Session()
+ kwargs['session'] = session
+ result = func(*args, **kwargs)
+ if needs_session:
+ session.expunge_all()
+ session.commit()
+ session.close()
+ return result
+ return wrapper
+
+
+def pessimistic_connection_handling():
+ @event.listens_for(Pool, "checkout")
+ def ping_connection(dbapi_connection, connection_record, connection_proxy):
+ '''
+ Disconnect Handling - Pessimistic, taken from:
+ http://docs.sqlalchemy.org/en/rel_0_9/core/pooling.html
+ '''
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute("SELECT 1")
+ except:
+ raise exc.DisconnectionError()
+ cursor.close()
+
+
+@provide_session
+def merge_conn(conn, session=None):
+ from airflow import models
+ C = models.Connection
+ if not session.query(C).filter(C.conn_id == conn.conn_id).first():
+ session.add(conn)
+ session.commit()
+
+
+def initdb():
+ session = settings.Session()
+
+ from airflow import models
+ upgradedb()
+
+ merge_conn(
+ models.Connection(
+ conn_id='airflow_db', conn_type='mysql',
+ host='localhost', login='root', password='',
+ schema='airflow'))
+ merge_conn(
+ models.Connection(
+ conn_id='airflow_ci', conn_type='mysql',
+ host='localhost', login='root',
+ schema='airflow_ci'))
+ merge_conn(
+ models.Connection(
+ conn_id='beeline_default', conn_type='beeline', port="10000",
+ host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}",
+ schema='default'))
+ merge_conn(
+ models.Connection(
+ conn_id='bigquery_default', conn_type='bigquery'))
+ merge_conn(
+ models.Connection(
+ conn_id='local_mysql', conn_type='mysql',
+ host='localhost', login='airflow', password='airflow',
+ schema='airflow'))
+ merge_conn(
+ models.Connection(
+ conn_id='presto_default', conn_type='presto',
+ host='localhost',
+ schema='hive', port=3400))
+ merge_conn(
+ models.Connection(
+ conn_id='hive_cli_default', conn_type='hive_cli',
+ schema='default',))
+ merge_conn(
+ models.Connection(
+ conn_id='hiveserver2_default', conn_type='hiveserver2',
+ host='localhost',
+ schema='default', port=10000))
+ merge_conn(
+ models.Connection(
+ conn_id='metastore_default', conn_type='hive_metastore',
+ host='localhost', extra="{\"authMechanism\": \"PLAIN\"}",
+ port=9083))
+ merge_conn(
+ models.Connection(
+ conn_id='mysql_default', conn_type='mysql',
+ login='root',
+ host='localhost'))
+ merge_conn(
+ models.Connection(
+ conn_id='postgres_default', conn_type='postgres',
+ login='postgres',
+ schema='airflow',
+ host='localhost'))
+ merge_conn(
+ models.Connection(
+ conn_id='sqlite_default', conn_type='sqlite',
+ host='/tmp/sqlite_default.db'))
+ merge_conn(
+ models.Connection(
+ conn_id='http_default', conn_type='http',
+ host='https://www.google.com/'))
+ merge_conn(
+ models.Connection(
+ conn_id='mssql_default', conn_type='mssql',
+ host='localhost', port=1433))
+ merge_conn(
+ models.Connection(
+ conn_id='vertica_default', conn_type='vertica',
+ host='localhost', port=5433))
+ merge_conn(
+ models.Connection(
+ conn_id='webhdfs_default', conn_type='hdfs',
+ host='localhost', port=50070))
+ merge_conn(
+ models.Connection(
+ conn_id='ssh_default', conn_type='ssh',
+ host='localhost'))
+
+ # Known event types
+ KET = models.KnownEventType
+ if not session.query(KET).filter(KET.know_event_type == 'Holiday').first():
+ session.add(KET(know_event_type='Holiday'))
+ if not session.query(KET).filter(KET.know_event_type == 'Outage').first():
+ session.add(KET(know_event_type='Outage'))
+ if not session.query(KET).filter(
+ KET.know_event_type == 'Natural Disaster').first():
+ session.add(KET(know_event_type='Natural Disaster'))
+ if not session.query(KET).filter(
+ KET.know_event_type == 'Marketing Campaign').first():
+ session.add(KET(know_event_type='Marketing Campaign'))
+ session.commit()
+
+ models.DagBag(sync_to_db=True)
+
+ Chart = models.Chart
+ chart_label = "Airflow task instance by type"
+ chart = session.query(Chart).filter(Chart.label == chart_label).first()
+ if not chart:
+ chart = Chart(
+ label=chart_label,
+ conn_id='airflow_db',
+ chart_type='bar',
+ x_is_date=False,
+ sql=(
+ "SELECT state, COUNT(1) as number "
+ "FROM task_instance "
+ "WHERE dag_id LIKE 'example%' "
+ "GROUP BY state"),
+ )
+ session.add(chart)
+
+
+def upgradedb():
+ logging.info("Creating tables")
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ package_dir = os.path.normpath(os.path.join(current_dir, '..'))
+ directory = os.path.join(package_dir, 'migrations')
+ config = Config(os.path.join(package_dir, 'alembic.ini'))
+ config.set_main_option('script_location', directory)
+ config.set_main_option('sqlalchemy.url',
+ configuration.get('core', 'SQL_ALCHEMY_CONN'))
+ command.upgrade(config, 'heads')
+
+
+def resetdb():
+ '''
+ Clear out the database
+ '''
+ from airflow import models
+
+ logging.info("Dropping tables that exist")
+ models.Base.metadata.drop_all(settings.engine)
+ mc = MigrationContext.configure(settings.engine)
+ if mc._version.exists(settings.engine):
+ mc._version.drop(settings.engine)
+ initdb()
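
A minimal sketch of how provide_session is meant to be used by callers (assumes a
configured Airflow metadata database; the query below is only an example):

    from airflow.utils.db import provide_session


    @provide_session
    def count_task_instances(dag_id, session=None):
        # when the caller omits `session`, the decorator opens one and
        # commits/closes it once the function returns
        from airflow import models
        TI = models.TaskInstance
        return session.query(TI).filter(TI.dag_id == dag_id).count()


    # uses its own short-lived session
    count_task_instances('example_bash_operator')
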
diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py
new file mode 100644
index 0000000000000..0f869073a218d
--- /dev/null
+++ b/airflow/utils/decorators.py
@@ -0,0 +1,67 @@
+import inspect
+import os
+
+from copy import copy
+from functools import wraps
+
+from airflow.exceptions import AirflowException
+
+
+def apply_defaults(func):
+ """
+    Function decorator that looks for an argument named "default_args", and
+ fills the unspecified arguments from it.
+
+ Since python2.* isn't clear about which arguments are missing when
+    calling a function, and since this can be quite confusing with multi-level
+ inheritance and argument defaults, this decorator also alerts with
+ specific information about the missing arguments.
+ """
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if len(args) > 1:
+ raise AirflowException(
+ "Use keyword arguments when initializing operators")
+ dag_args = {}
+ dag_params = {}
+ if 'dag' in kwargs and kwargs['dag']:
+ dag = kwargs['dag']
+ dag_args = copy(dag.default_args) or {}
+ dag_params = copy(dag.params) or {}
+
+ params = {}
+ if 'params' in kwargs:
+ params = kwargs['params']
+ dag_params.update(params)
+
+ default_args = {}
+ if 'default_args' in kwargs:
+ default_args = kwargs['default_args']
+ if 'params' in default_args:
+ dag_params.update(default_args['params'])
+ del default_args['params']
+
+ dag_args.update(default_args)
+ default_args = dag_args
+ arg_spec = inspect.getargspec(func)
+ num_defaults = len(arg_spec.defaults) if arg_spec.defaults else 0
+ non_optional_args = arg_spec.args[:-num_defaults]
+ if 'self' in non_optional_args:
+ non_optional_args.remove('self')
+ for arg in func.__code__.co_varnames:
+ if arg in default_args and arg not in kwargs:
+ kwargs[arg] = default_args[arg]
+ missing_args = list(set(non_optional_args) - set(kwargs))
+ if missing_args:
+ msg = "Argument {0} is required".format(missing_args)
+ raise AirflowException(msg)
+
+ kwargs['params'] = dag_params
+
+ result = func(*args, **kwargs)
+ return result
+ return wrapper
+
+if 'BUILDING_AIRFLOW_DOCS' in os.environ:
+ # Monkey patch hook to get good function headers while building docs
+ apply_defaults = lambda x: x
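
As a sketch of the decorator's behaviour outside a real operator (the class below is
hypothetical and only exists to show how default_args is resolved):

    from airflow.utils.decorators import apply_defaults


    class FakeOperator(object):
        @apply_defaults
        def __init__(self, owner, retries=0, params=None,
                     default_args=None, *args, **kwargs):
            self.owner = owner
            self.retries = retries


    # 'owner' is required but not passed directly; the decorator fills it
    # in from default_args before __init__ runs
    op = FakeOperator(default_args={'owner': 'airflow', 'retries': 3})
    assert (op.owner, op.retries) == ('airflow', 3)

    # FakeOperator('airflow') would raise AirflowException, since operators
    # must be initialized with keyword arguments only
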
diff --git a/airflow/utils/email.py b/airflow/utils/email.py
new file mode 100644
index 0000000000000..98f5eb2a625d4
--- /dev/null
+++ b/airflow/utils/email.py
@@ -0,0 +1,84 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from builtins import str
+from past.builtins import basestring
+
+import importlib
+import logging
+import os
+import smtplib
+
+from email.mime.text import MIMEText
+from email.mime.multipart import MIMEMultipart
+from email.mime.application import MIMEApplication
+from email.utils import formatdate
+
+from airflow import configuration
+
+
+def send_email(to, subject, html_content, files=None, dryrun=False):
+ """
+ Send email using backend specified in EMAIL_BACKEND.
+ """
+ path, attr = configuration.get('email', 'EMAIL_BACKEND').rsplit('.', 1)
+ module = importlib.import_module(path)
+ backend = getattr(module, attr)
+ return backend(to, subject, html_content, files=files, dryrun=dryrun)
+
+
+def send_email_smtp(to, subject, html_content, files=None, dryrun=False):
+ """
+ Send an email with html content
+
+ >>> send_email('test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True)
+ """
+ SMTP_MAIL_FROM = configuration.get('smtp', 'SMTP_MAIL_FROM')
+
+ if isinstance(to, basestring):
+ if ',' in to:
+ to = to.split(',')
+ elif ';' in to:
+ to = to.split(';')
+ else:
+ to = [to]
+
+ msg = MIMEMultipart('alternative')
+ msg['Subject'] = subject
+ msg['From'] = SMTP_MAIL_FROM
+ msg['To'] = ", ".join(to)
+ msg['Date'] = formatdate(localtime=True)
+ mime_text = MIMEText(html_content, 'html')
+ msg.attach(mime_text)
+
+ for fname in files or []:
+ basename = os.path.basename(fname)
+ with open(fname, "rb") as f:
+ msg.attach(MIMEApplication(
+ f.read(),
+ Content_Disposition='attachment; filename="%s"' % basename,
+ Name=basename
+ ))
+
+ send_MIME_email(SMTP_MAIL_FROM, to, msg, dryrun)
+
+
+def send_MIME_email(e_from, e_to, mime_msg, dryrun=False):
+ SMTP_HOST = configuration.get('smtp', 'SMTP_HOST')
+ SMTP_PORT = configuration.getint('smtp', 'SMTP_PORT')
+ SMTP_USER = configuration.get('smtp', 'SMTP_USER')
+ SMTP_PASSWORD = configuration.get('smtp', 'SMTP_PASSWORD')
+ SMTP_STARTTLS = configuration.getboolean('smtp', 'SMTP_STARTTLS')
+ SMTP_SSL = configuration.getboolean('smtp', 'SMTP_SSL')
+
+ if not dryrun:
+ s = smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT) if SMTP_SSL else smtplib.SMTP(SMTP_HOST, SMTP_PORT)
+ if SMTP_STARTTLS:
+ s.starttls()
+ if SMTP_USER and SMTP_PASSWORD:
+ s.login(SMTP_USER, SMTP_PASSWORD)
+ logging.info("Sent an alert email to " + str(e_to))
+ s.sendmail(e_from, e_to, mime_msg.as_string())
+ s.quit()
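
A quick usage sketch (the address is a placeholder; with dryrun=True the MIME message is
built but no SMTP connection is made, although the [smtp] and [email] sections still need
to exist in airflow.cfg for the configuration lookups):

    from airflow.utils.email import send_email, send_email_smtp

    # goes through whatever EMAIL_BACKEND is configured
    send_email('alerts@example.com', 'Test subject',
               '<p>Hello from Airflow</p>', dryrun=True)

    # or call the SMTP backend directly
    send_email_smtp('alerts@example.com', 'Test subject',
                    '<p>Hello from Airflow</p>', dryrun=True)
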
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
new file mode 100644
index 0000000000000..27a4233627d4c
--- /dev/null
+++ b/airflow/utils/file.py
@@ -0,0 +1,22 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import errno
+import shutil
+from tempfile import mkdtemp
+
+from contextlib import contextmanager
+
+
+@contextmanager
+def TemporaryDirectory(suffix='', prefix=None, dir=None):
+ name = mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
+ try:
+ yield name
+ finally:
+ try:
+ shutil.rmtree(name)
+ except OSError as e:
+ # ENOENT - no such file or directory
+ if e.errno != errno.ENOENT:
+ raise e
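
A short sketch of the context manager (standard library only):

    import os

    from airflow.utils.file import TemporaryDirectory

    with TemporaryDirectory(prefix='airflow_tmp_') as tmp_dir:
        scratch = os.path.join(tmp_dir, 'scratch.txt')
        with open(scratch, 'w') as f:
            f.write('intermediate data')
    # tmp_dir and its contents are gone once the block exits
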
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
new file mode 100644
index 0000000000000..bd94b72dfcd5a
--- /dev/null
+++ b/airflow/utils/helpers.py
@@ -0,0 +1,133 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from builtins import input
+from past.builtins import basestring
+from datetime import datetime
+import imp
+import logging
+import os
+import re
+
+from airflow.exceptions import AirflowException
+
+
+def validate_key(k, max_length=250):
+ if not isinstance(k, basestring):
+ raise TypeError("The key has to be a string")
+ elif len(k) > max_length:
+ raise AirflowException(
+ "The key has to be less than {0} characters".format(max_length))
+ elif not re.match(r'^[A-Za-z0-9_\-\.]+$', k):
+ raise AirflowException(
+ "The key ({k}) has to be made of alphanumeric characters, dashes, "
+ "dots and underscores exclusively".format(**locals()))
+ else:
+ return True
+
+
+def alchemy_to_dict(obj):
+ """
+ Transforms a SQLAlchemy model instance into a dictionary
+ """
+ if not obj:
+ return None
+ d = {}
+ for c in obj.__table__.columns:
+ value = getattr(obj, c.name)
+ if type(value) == datetime:
+ value = value.isoformat()
+ d[c.name] = value
+ return d
+
+
+def ask_yesno(question):
+ yes = set(['yes', 'y'])
+ no = set(['no', 'n'])
+
+ done = False
+ print(question)
+ while not done:
+ choice = input().lower()
+ if choice in yes:
+ return True
+ elif choice in no:
+ return False
+ else:
+            print("Please respond with yes or no.")
+
+
+def import_module_attrs(parent_module_globals, module_attrs_dict):
+ '''
+    Attempts to import a set of modules and specified attributes, given in
+    the form of a dictionary. The attributes are copied into the parent
+    module's namespace. The function returns a list of attribute names that
+    can be assigned to __all__.
+
+    This is used in the context of ``operators`` and ``hooks`` to
+    silence the import errors when libraries are missing. It makes
+    for a clean package, abstracting the underlying modules and only
+    bringing functional operators into those namespaces.
+ '''
+ imported_attrs = []
+ for mod, attrs in list(module_attrs_dict.items()):
+ try:
+ path = os.path.realpath(parent_module_globals['__file__'])
+ folder = os.path.dirname(path)
+ f, filename, description = imp.find_module(mod, [folder])
+ module = imp.load_module(mod, f, filename, description)
+ for attr in attrs:
+ parent_module_globals[attr] = getattr(module, attr)
+ imported_attrs += [attr]
+ except Exception as err:
+ logging.debug("Error importing module {mod}: {err}".format(
+ mod=mod, err=err))
+ return imported_attrs
+
+
+def is_in(obj, l):
+ """
+    Checks whether an object is one of the items in the list.
+    This is different from ``in`` because ``in`` uses __cmp__ when
+    present. Here we check based on the identity of the object itself.
+ """
+ for item in l:
+ if item is obj:
+ return True
+ return False
+
+
+def is_container(obj):
+ """
+ Test if an object is a container (iterable) but not a string
+ """
+ return hasattr(obj, '__iter__') and not isinstance(obj, basestring)
+
+
+def as_tuple(obj):
+ """
+ If obj is a container, returns obj as a tuple.
+ Otherwise, returns a tuple containing obj.
+ """
+ if is_container(obj):
+ return tuple(obj)
+ else:
+ return tuple([obj])
+
+
+def chain(*tasks):
+ """
+ Given a number of tasks, builds a dependency chain.
+
+ chain(task_1, task_2, task_3, task_4)
+
+ is equivalent to
+
+ task_1.set_downstream(task_2)
+ task_2.set_downstream(task_3)
+ task_3.set_downstream(task_4)
+ """
+ for up_task, down_task in zip(tasks[:-1], tasks[1:]):
+ up_task.set_downstream(down_task)
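
A few illustrative calls against the pure helpers above (chain is described rather than
run, since it needs task-like objects exposing set_downstream):

    from airflow.utils.helpers import as_tuple, is_in, validate_key

    validate_key('extract_orders')     # True; bad characters raise AirflowException
    as_tuple('single_value')           # ('single_value',)
    as_tuple(['a', 'b'])               # ('a', 'b')

    marker = object()
    is_in(marker, [object(), marker])  # True -- identity, not equality

    # chain(t1, t2, t3) simply wires t1 -> t2 -> t3 on any objects that
    # expose set_downstream, e.g. BaseOperator instances inside a DAG
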
diff --git a/airflow/utils/json.py b/airflow/utils/json.py
new file mode 100644
index 0000000000000..a6aaf89f81aee
--- /dev/null
+++ b/airflow/utils/json.py
@@ -0,0 +1,40 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from datetime import datetime, date
+import json
+import numpy as np
+
+
+# Dates and JSON encoding/decoding
+
+def json_ser(obj):
+ """
+ json serializer that deals with dates
+ usage: json.dumps(object, default=utils.json.json_ser)
+ """
+ if isinstance(obj, (datetime, date)):
+ return obj.isoformat()
+
+
+class AirflowJsonEncoder(json.JSONEncoder):
+ def default(self, obj):
+ # convert dates and numpy objects in a json serializable format
+ if isinstance(obj, datetime):
+ return obj.strftime('%Y-%m-%dT%H:%M:%SZ')
+ elif isinstance(obj, date):
+ return obj.strftime('%Y-%m-%d')
+ elif type(obj) in [np.int_, np.intc, np.intp, np.int8, np.int16,
+ np.int32, np.int64, np.uint8, np.uint16,
+ np.uint32, np.uint64]:
+ return int(obj)
+ elif type(obj) in [np.bool_]:
+ return bool(obj)
+ elif type(obj) in [np.float_, np.float16, np.float32, np.float64,
+ np.complex_, np.complex64, np.complex128]:
+ return float(obj)
+
+ # Let the base class default method raise the TypeError
+ return json.JSONEncoder.default(self, obj)
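
A small sketch of the encoder in action (numpy and datetime values are coerced to plain
JSON types):

    import json
    from datetime import datetime

    import numpy as np

    from airflow.utils.json import AirflowJsonEncoder

    payload = {'tries': np.int64(3), 'when': datetime(2016, 1, 1, 12, 30)}
    print(json.dumps(payload, cls=AirflowJsonEncoder))
    # {"tries": 3, "when": "2016-01-01T12:30:00Z"}  (key order may vary)
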
diff --git a/airflow/utils/logging.py b/airflow/utils/logging.py
new file mode 100644
index 0000000000000..1231a494e6558
--- /dev/null
+++ b/airflow/utils/logging.py
@@ -0,0 +1,123 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from builtins import object
+
+import logging
+
+from airflow import configuration
+
+
+class LoggingMixin(object):
+ """
+ Convenience super-class to have a logger configured with the class name
+ """
+
+ @property
+ def logger(self):
+ try:
+ return self._logger
+ except AttributeError:
+ self._logger = logging.root.getChild(self.__class__.__module__ + '.' + self.__class__.__name__)
+ return self._logger
+
+
+class GCSLog(object):
+ """
+ Utility class for reading and writing logs in GCS.
+ Requires either airflow[gcloud] or airflow[gcp_api] and
+ setting the REMOTE_BASE_LOG_FOLDER and REMOTE_LOG_CONN_ID configuration
+ options in airflow.cfg.
+ """
+ def __init__(self):
+ """
+        Attempts to create a hook with airflow[gcloud] (and sets
+        use_gcloud = True); otherwise falls back to airflow[gcp_api].
+ """
+ remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID')
+ self.use_gcloud = False
+
+ try:
+ from airflow.contrib.hooks import GCSHook
+ self.hook = GCSHook(remote_conn_id)
+ self.use_gcloud = True
+ except:
+ try:
+ from airflow.contrib.hooks import GoogleCloudStorageHook
+ self.hook = GoogleCloudStorageHook(remote_conn_id)
+ except:
+ self.hook = None
+ logging.error(
+ 'Could not create a GCSHook with connection id "{}". '
+ 'Please make sure that either airflow[gcloud] or '
+ 'airflow[gcp_api] is installed and the GCS connection '
+ 'exists.'.format(remote_conn_id))
+
+ def read(self, remote_log_location, return_error=True):
+ """
+ Returns the log found at the remote_log_location.
+
+ :param remote_log_location: the log's location in remote storage
+ :type remote_log_location: string (path)
+ :param return_error: if True, returns a string error message if an
+ error occurs. Otherwise returns '' when an error occurs.
+ :type return_error: bool
+ """
+ if self.hook:
+ try:
+ if self.use_gcloud:
+ gcs_blob = self.hook.get_blob(remote_log_location)
+ if gcs_blob:
+ return gcs_blob.download_as_string().decode()
+ else:
+ bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1)
+ return self.hook.download(bkt, blob).decode()
+ except:
+ pass
+
+ # raise/return error if we get here
+ err = 'Could not read logs from {}'.format(remote_log_location)
+ logging.error(err)
+ return err if return_error else ''
+
+ def write(self, log, remote_log_location, append=False):
+ """
+ Writes the log to the remote_log_location. Fails silently if no hook
+ was created.
+
+ :param log: the log to write to the remote_log_location
+ :type log: string
+ :param remote_log_location: the log's location in remote storage
+ :type remote_log_location: string (path)
+ :param append: if False, any existing log file is overwritten. If True,
+ the new log is appended to any existing logs.
+ :type append: bool
+
+ """
+ if self.hook:
+
+ if append:
+ old_log = self.read(remote_log_location)
+ log = old_log + '\n' + log
+
+ try:
+ if self.use_gcloud:
+ self.hook.upload_from_string(
+ log,
+ blob=remote_log_location,
+ replace=True)
+ return
+ else:
+ bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1)
+ from tempfile import NamedTemporaryFile
+ with NamedTemporaryFile(mode='w+') as tmpfile:
+ tmpfile.write(log)
+ self.hook.upload(bkt, blob, tmpfile.name)
+ return
+ except:
+ pass
+
+ # raise/return error if we get here
+ logging.error('Could not write logs to {}'.format(remote_log_location))
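
A minimal sketch of LoggingMixin; the class name becomes part of the logger name, which
makes log lines easy to attribute:

    import logging

    from airflow.utils.logging import LoggingMixin

    logging.basicConfig(level=logging.INFO)


    class Worker(LoggingMixin):
        def run(self):
            # the logger is named '<module>.Worker', a child of the root logger
            self.logger.info("starting run")


    Worker().run()
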
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
new file mode 100644
index 0000000000000..9657375c9681c
--- /dev/null
+++ b/airflow/utils/state.py
@@ -0,0 +1,50 @@
+from __future__ import unicode_literals
+
+from builtins import object
+
+
+class State(object):
+ """
+ Static class with task instance states constants and color method to
+ avoid hardcoding.
+ """
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ SHUTDOWN = "shutdown" # External request to shut down
+ FAILED = "failed"
+ UP_FOR_RETRY = "up_for_retry"
+ UPSTREAM_FAILED = "upstream_failed"
+ SKIPPED = "skipped"
+
+ state_color = {
+ QUEUED: 'gray',
+ RUNNING: 'lime',
+ SUCCESS: 'green',
+ SHUTDOWN: 'blue',
+ FAILED: 'red',
+ UP_FOR_RETRY: 'gold',
+ UPSTREAM_FAILED: 'orange',
+ SKIPPED: 'pink',
+ }
+
+ @classmethod
+ def color(cls, state):
+ if state in cls.state_color:
+ return cls.state_color[state]
+ else:
+ return 'white'
+
+ @classmethod
+ def color_fg(cls, state):
+ color = cls.color(state)
+ if color in ['green', 'red']:
+ return 'white'
+ else:
+ return 'black'
+
+ @classmethod
+ def runnable(cls):
+ return [
+ None, cls.FAILED, cls.UP_FOR_RETRY, cls.UPSTREAM_FAILED,
+ cls.SKIPPED, cls.QUEUED]
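
A couple of illustrative calls against the new State helper:

    from airflow.utils.state import State

    State.color(State.SUCCESS)         # 'green'
    State.color('not_a_state')         # 'white' fallback
    State.color_fg(State.FAILED)       # 'white' text on the red badge
    State.RUNNING in State.runnable()  # False; runnable() lists pre-run states
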
diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py
new file mode 100644
index 0000000000000..eaf15b4fb0b41
--- /dev/null
+++ b/airflow/utils/timeout.py
@@ -0,0 +1,39 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import logging
+import signal
+
+from builtins import object
+
+from airflow.exceptions import AirflowTaskTimeout
+
+
+class timeout(object):
+ """
+ To be used in a ``with`` block and timeout its content.
+ """
+ def __init__(self, seconds=1, error_message='Timeout'):
+ self.seconds = seconds
+ self.error_message = error_message
+
+ def handle_timeout(self, signum, frame):
+ logging.error("Process timed out")
+ raise AirflowTaskTimeout(self.error_message)
+
+ def __enter__(self):
+ try:
+ signal.signal(signal.SIGALRM, self.handle_timeout)
+ signal.alarm(self.seconds)
+ except ValueError as e:
+ logging.warning("timeout can't be used in the current context")
+ logging.exception(e)
+
+ def __exit__(self, type, value, traceback):
+ try:
+ signal.alarm(0)
+ except ValueError as e:
+ logging.warning("timeout can't be used in the current context")
+ logging.exception(e)
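
A sketch of the context manager; it is SIGALRM based, so it only works in the main
thread on Unix (hence the ValueError guards above):

    from time import sleep

    from airflow.exceptions import AirflowTaskTimeout
    from airflow.utils.timeout import timeout

    try:
        with timeout(seconds=1, error_message='task took too long'):
            sleep(5)  # longer than the allowed window
    except AirflowTaskTimeout:
        print("interrupted after one second")
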
diff --git a/airflow/utils/trigger_rule.py b/airflow/utils/trigger_rule.py
new file mode 100644
index 0000000000000..bfa05419ece10
--- /dev/null
+++ b/airflow/utils/trigger_rule.py
@@ -0,0 +1,22 @@
+from __future__ import unicode_literals
+
+from builtins import object
+
+
+class TriggerRule(object):
+ ALL_SUCCESS = 'all_success'
+ ALL_FAILED = 'all_failed'
+ ALL_DONE = 'all_done'
+ ONE_SUCCESS = 'one_success'
+ ONE_FAILED = 'one_failed'
+ DUMMY = 'dummy'
+
+ @classmethod
+ def is_valid(cls, trigger_rule):
+ return trigger_rule in cls.all_triggers()
+
+ @classmethod
+ def all_triggers(cls):
+ return [getattr(cls, attr)
+ for attr in dir(cls)
+ if not attr.startswith("__") and not callable(getattr(cls, attr))]
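
And a short sketch of the TriggerRule helper:

    from airflow.utils.trigger_rule import TriggerRule

    TriggerRule.is_valid('one_success')  # True
    TriggerRule.is_valid('two_success')  # False
    TriggerRule.all_triggers()
    # ['all_done', 'all_failed', 'all_success', 'dummy', 'one_failed', 'one_success']
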
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 03abf83a6ef3c..69dddf1faa5fa 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -8,14 +8,15 @@
import gzip
import dateutil.parser as dateparser
import json
-import os
from flask import after_this_request, request, Response
from flask_login import current_user
from jinja2 import Template
import wtforms
from wtforms.compat import text_type
-from airflow import configuration, models, settings, utils
+from airflow import configuration, models, settings
+from airflow.utils.json import AirflowJsonEncoder
+from airflow.utils.email import send_email
AUTHENTICATE = configuration.getboolean('webserver', 'AUTHENTICATE')
@@ -147,7 +148,7 @@ def wrapper(*args, **kwargs):
''').render(**locals())
if task.email:
- utils.send_email(task.email, subject, content)
+ send_email(task.email, subject, content)
"""
return f(*args, **kwargs)
return wrapper
@@ -159,7 +160,7 @@ def json_response(obj):
"""
return Response(
response=json.dumps(
- obj, indent=4, cls=utils.AirflowJsonEncoder),
+ obj, indent=4, cls=AirflowJsonEncoder),
status=200,
mimetype="application/json")
diff --git a/airflow/www/views.py b/airflow/www/views.py
index cbb2db39ee42f..d7864ec746f3f 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -37,15 +37,18 @@
from pygments.formatters import HtmlFormatter
import airflow
-from airflow import models
-from airflow.settings import Session
from airflow import configuration as conf
-from airflow import utils
-from airflow.utils import AirflowException
-from airflow.www import utils as wwwutils
+from airflow import models
from airflow import settings
-from airflow.models import State
+from airflow.exceptions import AirflowException
+from airflow.settings import Session
+from airflow.utils.json import json_ser
+from airflow.utils.state import State
+from airflow.utils.db import provide_session
+from airflow.utils.helpers import alchemy_to_dict
+from airflow.utils import logging as log_utils
+from airflow.www import utils as wwwutils
from airflow.www.forms import DateTimeForm, DateTimeWithNumRunsForm
QUERY_LIMIT = 100000
@@ -637,7 +640,7 @@ def dag_details(self):
)
return self.render(
'airflow/dag_details.html',
- dag=dag, title=title, states=states, State=utils.State)
+ dag=dag, title=title, states=states, State=State)
@current_app.errorhandler(404)
def circles(self):
@@ -646,7 +649,7 @@ def circles(self):
@current_app.errorhandler(500)
def show_traceback(self):
- from airflow import ascii as ascii_
+ from airflow.utils import asciiart as ascii_
return render_template(
'airflow/traceback.html',
hostname=socket.gethostname(),
@@ -803,11 +806,11 @@ def log(self):
# S3
if remote_log.startswith('s3:/'):
- log += utils.S3Log().read(remote_log, return_error=True)
+ log += log_utils.S3Log().read(remote_log, return_error=True)
# GCS
elif remote_log.startswith('gs:/'):
- log += utils.GCSLog().read(remote_log, return_error=True)
+ log += log_utils.GCSLog().read(remote_log, return_error=True)
# unsupported
elif remote_log:
@@ -1138,7 +1141,7 @@ def tree(self):
.all()
)
dag_runs = {
- dr.execution_date: utils.alchemy_to_dict(dr) for dr in dag_runs}
+ dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs}
tis = dag.get_task_instances(
session, start_date=min_date, end_date=base_date)
@@ -1146,7 +1149,7 @@ def tree(self):
max_date = max([ti.execution_date for ti in tis]) if dates else None
task_instances = {}
for ti in tis:
- tid = utils.alchemy_to_dict(ti)
+ tid = alchemy_to_dict(ti)
dr = dag_runs.get(ti.execution_date)
tid['external_trigger'] = dr['external_trigger'] if dr else False
task_instances[(ti.task_id, ti.execution_date)] = tid
@@ -1201,7 +1204,7 @@ def recurse_nodes(task, visited):
for d in dates],
}
- data = json.dumps(data, indent=4, default=utils.json_ser)
+ data = json.dumps(data, indent=4, default=json_ser)
session.commit()
session.close()
@@ -1294,7 +1297,7 @@ class GraphForm(Form):
data={'execution_date': dttm.isoformat(), 'arrange': arrange})
task_instances = {
- ti.task_id: utils.alchemy_to_dict(ti)
+ ti.task_id: alchemy_to_dict(ti)
for ti in dag.get_task_instances(session, dttm, dttm)}
tasks = {
t.task_id: {
@@ -1593,7 +1596,7 @@ def task_instances(self):
return ("Error: Invalid execution_date")
task_instances = {
- ti.task_id: utils.alchemy_to_dict(ti)
+ ti.task_id: alchemy_to_dict(ti)
for ti in dag.get_task_instances(session, dttm, dttm)}
return json.dumps(task_instances)
@@ -1979,7 +1982,7 @@ def action_set_failed(self, ids):
def action_set_success(self, ids):
self.set_dagrun_state(ids, State.SUCCESS)
- @utils.provide_session
+ @provide_session
def set_dagrun_state(self, ids, target_state, session=None):
try:
DR = models.DagRun
@@ -2058,7 +2061,7 @@ def action_set_success(self, ids):
def action_set_retry(self, ids):
self.set_task_instance_state(ids, State.UP_FOR_RETRY)
- @utils.provide_session
+ @provide_session
def set_task_instance_state(self, ids, target_state, session=None):
try:
TI = models.TaskInstance
diff --git a/tests/core.py b/tests/core.py
index 5cc737dcf51fe..4a758009e0200 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -2,7 +2,6 @@
import doctest
import json
-import logging
import os
import re
import unittest
@@ -22,14 +21,17 @@
from airflow.models import Variable
configuration.test_mode()
-from airflow import jobs, models, DAG, utils, operators, hooks, macros, settings
+from airflow import jobs, models, DAG, utils, operators, hooks, macros, settings, exceptions
from airflow.hooks import BaseHook
from airflow.bin import cli
from airflow.www import app as application
from airflow.settings import Session
-from airflow.utils import LoggingMixin, round_time
+from airflow.utils.state import State
+from airflow.utils.dates import round_time
+from airflow.utils.logging import LoggingMixin
+from airflow.utils import email as email_utils
from lxml import html
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
from airflow.configuration import AirflowConfigException
from airflow.minihivecluster import MiniHiveCluster
@@ -121,7 +123,7 @@ def test_schedule_dag_no_previous_runs(self):
assert dag_run.execution_date == datetime(2015, 1, 2, 0, 0), (
'dag_run.execution_date did not match expectation: {0}'
.format(dag_run.execution_date))
- assert dag_run.state == models.State.RUNNING
+ assert dag_run.state == State.RUNNING
assert dag_run.external_trigger == False
def test_schedule_dag_fake_scheduled_previous(self):
@@ -141,7 +143,7 @@ def test_schedule_dag_fake_scheduled_previous(self):
dag_id=dag.dag_id,
run_id=models.DagRun.id_for_date(DEFAULT_DATE),
execution_date=DEFAULT_DATE,
- state=utils.State.SUCCESS,
+ state=State.SUCCESS,
external_trigger=True)
settings.Session().add(trigger)
settings.Session().commit()
@@ -153,7 +155,7 @@ def test_schedule_dag_fake_scheduled_previous(self):
assert dag_run.execution_date == DEFAULT_DATE + delta, (
'dag_run.execution_date did not match expectation: {0}'
.format(dag_run.execution_date))
- assert dag_run.state == models.State.RUNNING
+ assert dag_run.state == State.RUNNING
assert dag_run.external_trigger == False
def test_schedule_dag_once(self):
@@ -239,7 +241,7 @@ def test_schedule_dag_no_end_date_up_to_today_only(self):
dag_runs.append(dag_run)
# Mark the DagRun as complete
- dag_run.state = utils.State.SUCCESS
+ dag_run.state = State.SUCCESS
session.merge(dag_run)
session.commit()
@@ -461,7 +463,7 @@ def test_timeout(self):
python_callable=lambda: sleep(5),
dag=self.dag)
self.assertRaises(
- utils.AirflowTaskTimeout,
+ exceptions.AirflowTaskTimeout,
t.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
@@ -1255,7 +1257,7 @@ def test_poke(self):
class ConnectionTest(unittest.TestCase):
def setUp(self):
configuration.test_mode()
- utils.initdb()
+ utils.db.initdb()
os.environ['AIRFLOW_CONN_TEST_URI'] = (
'postgres://username:password@ec2.compute.com:5432/the_database')
os.environ['AIRFLOW_CONN_TEST_URI_NO_CREDS'] = (
@@ -1396,16 +1398,16 @@ class EmailTest(unittest.TestCase):
def setUp(self):
configuration.remove_option('email', 'EMAIL_BACKEND')
- @mock.patch('airflow.utils.send_email_smtp')
+    @mock.patch('airflow.utils.email.send_email_smtp')
def test_default_backend(self, mock_send_email):
- res = utils.send_email('to', 'subject', 'content')
+ res = email_utils.send_email('to', 'subject', 'content')
mock_send_email.assert_called_with('to', 'subject', 'content', files=None, dryrun=False)
assert res == mock_send_email.return_value
- @mock.patch('airflow.utils.send_email_smtp')
+    @mock.patch('airflow.utils.email.send_email_smtp')
def test_custom_backend(self, mock_send_email):
configuration.set('email', 'EMAIL_BACKEND', 'tests.core.send_email_test')
- utils.send_email('to', 'subject', 'content')
+ email_utils.send_email('to', 'subject', 'content')
send_email_test.assert_called_with('to', 'subject', 'content', files=None, dryrun=False)
assert not mock_send_email.called
@@ -1414,12 +1416,12 @@ class EmailSmtpTest(unittest.TestCase):
def setUp(self):
configuration.set('smtp', 'SMTP_SSL', 'False')
- @mock.patch('airflow.utils.send_MIME_email')
+    @mock.patch('airflow.utils.email.send_MIME_email')
def test_send_smtp(self, mock_send_mime):
attachment = tempfile.NamedTemporaryFile()
attachment.write(b'attachment')
attachment.seek(0)
- utils.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
+ email_utils.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
assert mock_send_mime.called
call_args = mock_send_mime.call_args[0]
assert call_args[0] == configuration.get('smtp', 'SMTP_MAIL_FROM')
@@ -1437,7 +1439,7 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
msg = MIMEMultipart()
- utils.send_MIME_email('from', 'to', msg, dryrun=False)
+ email_utils.send_MIME_email('from', 'to', msg, dryrun=False)
mock_smtp.assert_called_with(
configuration.get('smtp', 'SMTP_HOST'),
configuration.getint('smtp', 'SMTP_PORT'),
@@ -1456,7 +1458,7 @@ def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
configuration.set('smtp', 'SMTP_SSL', 'True')
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
- utils.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
+ email_utils.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
assert not mock_smtp.called
mock_smtp_ssl.assert_called_with(
configuration.get('smtp', 'SMTP_HOST'),
@@ -1466,7 +1468,7 @@ def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
- utils.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=True)
+ email_utils.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=True)
assert not mock_smtp.called
assert not mock_smtp_ssl.called
diff --git a/tests/operators/docker_operator.py b/tests/operators/docker_operator.py
index 15dca9d7b98cf..4f9004c6292ea 100644
--- a/tests/operators/docker_operator.py
+++ b/tests/operators/docker_operator.py
@@ -3,7 +3,7 @@
from airflow.operators.docker_operator import DockerOperator
from docker.client import Client
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
try:
from unittest import mock
@@ -16,7 +16,7 @@
class DockerOperatorTestCase(unittest.TestCase):
@unittest.skipIf(mock is None, 'mock package not present')
- @mock.patch('airflow.utils.mkdtemp')
+ @mock.patch('airflow.utils.file.mkdtemp')
@mock.patch('airflow.operators.docker_operator.Client')
def test_execute(self, client_class_mock, mkdtemp_mock):
host_config = mock.Mock()