From a234f53b761bc5a116e9ba974b233b8c785a32aa Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Wed, 5 May 2021 15:41:55 +0300 Subject: [PATCH] azure: use a custom event loop for authentication (#5958) * azure: use a custom event loop for authentication * add test_temp_event_loop * wrap open() too! * ignore aiohttp warnings for normal level use-cases * rename tests/unit/remote/test_azure to tests/unit/fs/test_azure --- dvc/fs/azure.py | 47 ++++++++++++++++++++++--- dvc/logger.py | 12 ++++++- tests/unit/{remote => fs}/test_azure.py | 32 ++++++++++++++++- 3 files changed, 84 insertions(+), 7 deletions(-) rename tests/unit/{remote => fs}/test_azure.py (59%) diff --git a/dvc/fs/azure.py b/dvc/fs/azure.py index 3b3b4ad77b..a25f7a397f 100644 --- a/dvc/fs/azure.py +++ b/dvc/fs/azure.py @@ -1,6 +1,8 @@ +import asyncio import logging import os import threading +from contextlib import contextmanager from funcy import cached_property, wrap_prop @@ -19,6 +21,34 @@ ) +@contextmanager +def _temp_event_loop(): + """When trying to initialize azure filesystem instances + with DefaultCredentials, the authentication process requires + to have access to a separate event loop. The normal calls + are run in a separate loop managed by fsspec, but the + DefaultCredentials assumes that the caller is managing their + own event loop. This function checks whether there is any + event loop set, and if not it creates a temporary event loop + and sets it. 
After the context exits, it restores the + original event loop (if there was one).""" + + try: + original_loop = asyncio.get_event_loop() + except RuntimeError: + original_loop = None + + loop = original_loop or asyncio.new_event_loop() + + try: + asyncio.set_event_loop(loop) + yield + finally: + if original_loop is None: + loop.close() + asyncio.set_event_loop(original_loop) + + class AzureAuthError(DvcException): pass @@ -120,11 +150,12 @@ def fs(self): from azure.core.exceptions import AzureError try: - file_system = AzureBlobFileSystem(**self.fs_args) - if self.bucket not in [ - container.rstrip("/") for container in file_system.ls("/") - ]: - file_system.mkdir(self.bucket) + with _temp_event_loop(): + file_system = AzureBlobFileSystem(**self.fs_args) + if self.bucket not in [ + container.rstrip("/") for container in file_system.ls("/") + ]: + file_system.mkdir(self.bucket) except (ValueError, AzureError) as e: raise AzureAuthError( f"Authentication to Azure Blob Storage via {self.login_method}" @@ -133,3 +164,9 @@ def fs(self): ) from e return file_system + + def open( + self, path_info, mode="r", **kwargs + ): # pylint: disable=arguments-differ + with _temp_event_loop(): + return self.fs.open(self._with_bucket(path_info), mode=mode) diff --git a/dvc/logger.py b/dvc/logger.py index e2eb5cd8f8..3de943115a 100644 --- a/dvc/logger.py +++ b/dvc/logger.py @@ -196,7 +196,17 @@ def disable_other_loggers(): def setup(level=logging.INFO): colorama.init() - logging.getLogger("asyncio").setLevel(logging.CRITICAL) + + if level >= logging.DEBUG: + # Unclosed session errors for asyncio/aiohttp are only shown + # in the tracing mode for extensive debug purposes. They are really + # noisy, and are potentially caused by the client library + # not managing their own session. 
Even though it is the best practice + # for them to do so, we can be sure that these errors are raised when + # the object is getting deallocated, so no need to take any extensive + # action. + logging.getLogger("asyncio").setLevel(logging.CRITICAL) + logging.getLogger("aiohttp").setLevel(logging.CRITICAL) addLoggingLevel("TRACE", logging.DEBUG - 5) logging.config.dictConfig( diff --git a/tests/unit/remote/test_azure.py b/tests/unit/fs/test_azure.py similarity index 59% rename from tests/unit/remote/test_azure.py rename to tests/unit/fs/test_azure.py index 34bddd0d1e..764d426139 100644 --- a/tests/unit/remote/test_azure.py +++ b/tests/unit/fs/test_azure.py @@ -1,4 +1,9 @@ -from dvc.fs.azure import AzureFileSystem +import asyncio +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from dvc.fs.azure import AzureFileSystem, _temp_event_loop from dvc.path_info import PathInfo container_name = "container-name" @@ -37,3 +42,28 @@ def test_info(tmp_dir, azure): hash_ = fs.info(to_info)["etag"] assert isinstance(hash_, str) assert hash_.strip("'").strip('"') == hash_ + + +def test_temp_event_loop(): + def procedure(): + loop = asyncio.get_event_loop() + loop.run_until_complete(asyncio.sleep(0)) + return "yeey" + + def wrapped_procedure(): + with _temp_event_loop(): + return procedure() + + # it should clean up the loop after + # exiting the context. + with pytest.raises(RuntimeError): + asyncio.get_event_loop() + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(procedure) + + with pytest.raises(RuntimeError): + future.result() + + future = executor.submit(wrapped_procedure) + assert future.result() == "yeey"