diff --git a/parfive/downloader.py b/parfive/downloader.py index 586f123..29c3bba 100644 --- a/parfive/downloader.py +++ b/parfive/downloader.py @@ -3,6 +3,8 @@ import asyncio import logging import pathlib +import warnings +import threading import contextlib import urllib.parse from typing import Union, Callable, Optional @@ -223,6 +225,14 @@ def filepath(url, resp): def _add_shutdown_signals(loop, task): if os.name == "nt": return + + if threading.current_thread() != threading.main_thread(): + warnings.warn( + "This download has been started in a thread which is not the main thread. You will not be able to interrupt the download.", + UserWarning, + ) + return + for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, task.cancel) diff --git a/parfive/tests/test_downloader.py b/parfive/tests/test_downloader.py index 602e556..58e8e6e 100644 --- a/parfive/tests/test_downloader.py +++ b/parfive/tests/test_downloader.py @@ -1,5 +1,6 @@ import os import platform +import threading from pathlib import Path from unittest import mock from unittest.mock import patch @@ -463,3 +464,34 @@ def test_proxy_passed_as_kwargs_to_get(tmpdir, url, proxy): "proxy": proxy, }, ] + + +class CustomThread(threading.Thread): + def __init__(self, *args, **kwargs): + self.result = None + super().__init__(*args, **kwargs) + + def run(self): + self.result = self._target(*self._args, **self._kwargs) + + +@skip_windows +def test_download_out_of_main_thread(httpserver, tmpdir): + tmpdir = str(tmpdir) + httpserver.serve_content( + "SIMPLE = T", headers={"Content-Disposition": "attachment; filename=testfile.fits"} + ) + dl = Downloader() + + dl.enqueue_file(httpserver.url, path=Path(tmpdir), max_splits=None) + + thread = CustomThread(target=dl.download) + thread.start() + + with pytest.warns( + UserWarning, + match="This download has been started in a thread which is not the main thread. You will not be able to interrupt the download.", + ): + thread.join() + + validate_test_file(thread.result)