Skip to content

Commit

Permalink
Fixed dask executor and tests (#22027)
Browse files Browse the repository at this point in the history
 Fixed dask executor and tests, distributed package does not ship with tests folder and the certificates, added certificates to certs folder

(cherry picked from commit d3c168c)
  • Loading branch information
subkanthi authored and ephraimbuddy committed Mar 26, 2022
1 parent 1613e03 commit e164cf9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion airflow/executors/dask_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def airflow_run():
raise AirflowException(f"Attempted to submit task to an unavailable queue: '{queue}'")
resources = {queue: 1}

future = self.client.submit(airflow_run, pure=False, resources=resources)
future = self.client.submit(subprocess.check_call, command, pure=False, resources=resources)
self.futures[future] = key # type: ignore

def _process_future(self, future: Future) -> None:
Expand Down
34 changes: 22 additions & 12 deletions tests/executors/test_dask_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,41 @@
from unittest import mock

import pytest
from distributed import LocalCluster

from airflow.exceptions import AirflowException
from airflow.executors.dask_executor import DaskExecutor
from airflow.jobs.backfill_job import BackfillJob
from airflow.models import DagBag
from airflow.utils import timezone
from tests.test_utils.config import conf_vars

try:
from distributed import LocalCluster

# utility functions imported from the dask testing suite to instantiate a test
# cluster for tls tests
from distributed import tests # noqa
from distributed.utils_test import cluster as dask_testing_cluster, get_cert, tls_security

from airflow.executors.dask_executor import DaskExecutor

skip_tls_tests = False
except ImportError:
skip_tls_tests = True
# In case the tests are skipped because of lacking test harness, get_cert should be
# overridden to avoid get_cert failing during test discovery as get_cert is used
# in conf_vars decorator
get_cert = lambda x: x

DEFAULT_DATE = timezone.datetime(2017, 1, 1)
SUCCESS_COMMAND = ['airflow', 'tasks', 'run', '--help']
FAIL_COMMAND = ['airflow', 'tasks', 'run', 'false']

# For now we are temporarily removing Dask support until we get Dask Team help us in making the
# tests pass again
skip_dask_tests = False


@pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it needs testing from Dask team")
class TestBaseDask(unittest.TestCase):
def assert_tasks_on_executor(self, executor):
def assert_tasks_on_executor(self, executor, timeout_executor=120):

# start the executor
executor.start()
Expand All @@ -58,7 +66,7 @@ def assert_tasks_on_executor(self, executor):
fail_future = next(k for k, v in executor.futures.items() if v == 'fail')

# wait for the futures to execute, with a timeout
timeout = timezone.utcnow() + timedelta(seconds=30)
timeout = timezone.utcnow() + timedelta(seconds=timeout_executor)
while not (success_future.done() and fail_future.done()):
if timezone.utcnow() > timeout:
raise ValueError(
Expand All @@ -75,14 +83,15 @@ def assert_tasks_on_executor(self, executor):
assert fail_future.exception() is not None


@pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it needs testing from Dask team")
class TestDaskExecutor(TestBaseDask):
def setUp(self):
self.dagbag = DagBag(include_examples=True)
self.cluster = LocalCluster()

def test_dask_executor_functions(self):
executor = DaskExecutor(cluster_address=self.cluster.scheduler_address)
self.assert_tasks_on_executor(executor)
self.assert_tasks_on_executor(executor, timeout_executor=120)

def test_backfill_integration(self):
"""
Expand Down Expand Up @@ -112,9 +121,9 @@ def setUp(self):

@conf_vars(
{
('dask', 'tls_ca'): get_cert('tls-ca-cert.pem'),
('dask', 'tls_cert'): get_cert('tls-key-cert.pem'),
('dask', 'tls_key'): get_cert('tls-key.pem'),
('dask', 'tls_ca'): 'certs/tls-ca-cert.pem',
('dask', 'tls_cert'): 'certs/tls-key-cert.pem',
('dask', 'tls_key'): 'certs/tls-key.pem',
}
)
def test_tls(self):
Expand All @@ -127,7 +136,7 @@ def test_tls(self):

executor = DaskExecutor(cluster_address=cluster['address'])

self.assert_tasks_on_executor(executor)
self.assert_tasks_on_executor(executor, timeout_executor=120)

executor.end()
# close the executor, the cluster context manager expects all listeners
Expand All @@ -148,6 +157,7 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock
mock_stats_gauge.assert_has_calls(calls)


@pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it needs testing from Dask team")
class TestDaskExecutorQueue(unittest.TestCase):
def test_dask_queues_no_resources(self):
self.cluster = LocalCluster()
Expand Down Expand Up @@ -175,7 +185,7 @@ def test_dask_queues(self):
success_future = next(k for k, v in executor.futures.items() if v == 'success')

# wait for the futures to execute, with a timeout
timeout = timezone.utcnow() + timedelta(seconds=30)
timeout = timezone.utcnow() + timedelta(seconds=120)
while not success_future.done():
if timezone.utcnow() > timeout:
raise ValueError(
Expand Down

0 comments on commit e164cf9

Please sign in to comment.