From 3b8f7d8f1e14f67550ed191ba70d856f4e36db12 Mon Sep 17 00:00:00 2001 From: Adrian Stachlewski Date: Fri, 10 Jun 2022 19:35:28 +0200 Subject: [PATCH] Close requests.Socket in RemoteScheduler before exiting (#3173) --- luigi/interface.py | 2 ++ luigi/rpc.py | 30 +++++++++++++++++++++++++----- test/rpc_test.py | 5 ++--- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/luigi/interface.py b/luigi/interface.py index 0f51b07783..5fd33b947e 100644 --- a/luigi/interface.py +++ b/luigi/interface.py @@ -173,6 +173,8 @@ def _schedule_and_run(tasks, worker_scheduler_factory=None, override_defaults=No success &= worker.run() luigi_run_result = LuigiRunResult(worker, success) logger.info(luigi_run_result.summary_text) + if hasattr(sch, 'close'): + sch.close() return luigi_run_result diff --git a/luigi/rpc.py b/luigi/rpc.py index e30146c7bd..2066cbbc62 100644 --- a/luigi/rpc.py +++ b/luigi/rpc.py @@ -19,6 +19,7 @@ rpc.py implements the client side of it, server.py implements the server side. See :doc:`/central_scheduler` for more info. """ +import abc import os import json import logging @@ -69,7 +70,17 @@ def __init__(self, message, sub_exception=None): self.sub_exception = sub_exception -class URLLibFetcher: +class _FetcherInterface(metaclass=abc.ABCMeta): + @abc.abstractmethod + def fetch(self, full_url, body, timeout): + pass + + @abc.abstractmethod + def close(self): + pass + + +class URLLibFetcher(_FetcherInterface): raises = (URLError, socket.timeout) def _create_request(self, full_url, body=None): @@ -96,12 +107,15 @@ def fetch(self, full_url, body, timeout): req = self._create_request(full_url, body=body) return urlopen(req, timeout=timeout).read().decode('utf-8') + def close(self): + pass -class RequestsFetcher: - def __init__(self, session): + +class RequestsFetcher(_FetcherInterface): + def __init__(self): from requests import exceptions as requests_exceptions self.raises = requests_exceptions.RequestException - self.session = session + self.session = requests.Session() self.process_id = os.getpid() def check_pid(self): @@ -117,6 +131,9 @@ def fetch(self, full_url, body, timeout): resp.raise_for_status() return resp.text + def close(self): + self.session.close() + class RemoteScheduler: """ @@ -140,10 +157,13 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._rpc_log_retries = config.getboolean('core', 'rpc-log-retries', True) if HAS_REQUESTS: - self._fetcher = RequestsFetcher(requests.Session()) + self._fetcher = RequestsFetcher() else: self._fetcher = URLLibFetcher() + def close(self): + self._fetcher.close() + def _get_retryer(self): def retry_logging(retry_state): if self._rpc_log_retries: diff --git a/test/rpc_test.py b/test/rpc_test.py index 1537f5c9e2..d99152a8c4 100644 --- a/test/rpc_test.py +++ b/test/rpc_test.py @@ -27,7 +27,6 @@ from server_test import ServerTestBase import socket from multiprocessing import Process, Queue -import requests class RemoteSchedulerTest(unittest.TestCase): @@ -147,8 +146,8 @@ def test_get_work_speed(self): class RequestsFetcherTest(ServerTestBase): def test_fork_changes_session(self): - session = requests.Session() - fetcher = luigi.rpc.RequestsFetcher(session) + fetcher = luigi.rpc.RequestsFetcher() + session = fetcher.session q = Queue()